mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-01-09 13:37:58 -05:00
edaBits, ChaiGear, TopGear, CCD.
This commit is contained in:
@@ -233,6 +233,8 @@ public:
|
||||
template <class T>
|
||||
static void ands(T& processor, const vector<int>& args) { processor.ands(args); }
|
||||
template <class T>
|
||||
static void xors(T& processor, const vector<int>& args) { processor.xors(args); }
|
||||
template <class T>
|
||||
static void inputb(T& processor, const vector<int>& args) { processor.input(args); }
|
||||
template <class T>
|
||||
static T get_input(int from, GC::Processor<T>& processor, int n_bits)
|
||||
|
||||
18
BMR/common.h
18
BMR/common.h
@@ -11,6 +11,8 @@
|
||||
#include <vector>
|
||||
using namespace std;
|
||||
|
||||
#include "Tools/CheckVector.h"
|
||||
|
||||
typedef unsigned long wire_id_t;
|
||||
typedef unsigned long gate_id_t;
|
||||
typedef unsigned int party_id_t;
|
||||
@@ -37,20 +39,4 @@ public:
|
||||
bool call(bool left, bool right) { return rep[2 * left + right]; }
|
||||
};
|
||||
|
||||
template <class T>
|
||||
class CheckVector : public vector<T>
|
||||
{
|
||||
public:
|
||||
CheckVector() : vector<T>() {}
|
||||
CheckVector(size_t size) : vector<T>(size) {}
|
||||
CheckVector(size_t size, const T& def) : vector<T>(size, def) {}
|
||||
#ifdef CHECK_SIZE
|
||||
T& operator[](size_t i) { return this->at(i); }
|
||||
const T& operator[](size_t i) const { return this->at(i); }
|
||||
#else
|
||||
T& at(size_t i) { return (*this)[i]; }
|
||||
const T& at(size_t i) const { return (*this)[i]; }
|
||||
#endif
|
||||
};
|
||||
|
||||
#endif /* CIRCUIT_INC_COMMON_H_ */
|
||||
|
||||
12
CHANGELOG.md
12
CHANGELOG.md
@@ -1,5 +1,17 @@
|
||||
The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here.
|
||||
|
||||
## 0.1.5 (Mar 20, 2020)
|
||||
|
||||
- Faster conversion between arithmetic and binary secret sharing using [extended daBits](https://eprint.iacr.org/2020/338)
|
||||
- Optimized daBits
|
||||
- Optimized logistic regression
|
||||
- Faster compilation of repetitive code (compiler option `-C`)
|
||||
- ChaiGear: [HighGear](https://eprint.iacr.org/2017/1230) with covert key generation
|
||||
- [TopGear](https://eprint.iacr.org/2019/035) zero-knowledge proofs
|
||||
- Binary computation based on Shamir secret sharing
|
||||
- Fixed security bug: Prove correctness of ciphertexts in input tuple generation
|
||||
- Fixed security bug: Missing check in MASCOT bit generation and various binary computations
|
||||
|
||||
## 0.1.4 (Dec 23, 2019)
|
||||
|
||||
- Mixed circuit computation with secret sharing
|
||||
|
||||
@@ -36,6 +36,8 @@ opcodes = dict(
|
||||
STMSBI = 0x243,
|
||||
MOVSB = 0x244,
|
||||
INPUTB = 0x246,
|
||||
SPLIT = 0x248,
|
||||
CONVCBIT2S = 0x249,
|
||||
XORCBI = 0x210,
|
||||
BITDECC = 0x211,
|
||||
CONVCINT = 0x213,
|
||||
@@ -49,15 +51,23 @@ opcodes = dict(
|
||||
MULCBI = 0x21c,
|
||||
SHRCBI = 0x21d,
|
||||
SHLCBI = 0x21e,
|
||||
CONVCINTVEC = 0x21f,
|
||||
PRINTREGSIGNED = 0x220,
|
||||
PRINTREGB = 0x221,
|
||||
PRINTREGPLAINB = 0x222,
|
||||
PRINTFLOATPLAINB = 0x223,
|
||||
CONDPRINTSTRB = 0x224,
|
||||
CONVCBIT = 0x230,
|
||||
CONVCBITVEC = 0x231,
|
||||
)
|
||||
|
||||
class xors(base.Instruction):
|
||||
class BinaryVectorInstruction(base.Instruction):
|
||||
is_vec = lambda self: True
|
||||
|
||||
def copy(self, size, subs):
|
||||
return type(self)(*self.get_new_args(size, subs))
|
||||
|
||||
class xors(BinaryVectorInstruction):
|
||||
code = opcodes['XORS']
|
||||
arg_format = tools.cycle(['int','sbw','sb','sb'])
|
||||
|
||||
@@ -73,15 +83,21 @@ class xorcbi(base.Instruction):
|
||||
code = opcodes['XORCBI']
|
||||
arg_format = ['cbw','cb','int']
|
||||
|
||||
class andrs(base.Instruction):
|
||||
class andrs(BinaryVectorInstruction):
|
||||
code = opcodes['ANDRS']
|
||||
arg_format = tools.cycle(['int','sbw','sb','sb'])
|
||||
|
||||
class ands(base.Instruction):
|
||||
def add_usage(self, req_node):
|
||||
req_node.increment(('bit', 'triple'), sum(self.args[::4]))
|
||||
|
||||
class ands(BinaryVectorInstruction):
|
||||
code = opcodes['ANDS']
|
||||
arg_format = tools.cycle(['int','sbw','sb','sb'])
|
||||
|
||||
class andm(base.Instruction):
|
||||
def add_usage(self, req_node):
|
||||
req_node.increment(('bit', 'triple'), sum(self.args[::4]))
|
||||
|
||||
class andm(BinaryVectorInstruction):
|
||||
code = opcodes['ANDM']
|
||||
arg_format = ['int','sbw','sb','cb']
|
||||
|
||||
@@ -181,6 +197,31 @@ class convcbit(base.Instruction):
|
||||
code = opcodes['CONVCBIT']
|
||||
arg_format = ['ciw','cb']
|
||||
|
||||
@base.vectorize
|
||||
class convcintvec(base.Instruction):
|
||||
code = opcodes['CONVCINTVEC']
|
||||
arg_format = tools.chain(['c'], tools.cycle(['cbw']))
|
||||
|
||||
class convcbitvec(BinaryVectorInstruction):
|
||||
code = opcodes['CONVCBITVEC']
|
||||
arg_format = ['int','ciw','cb']
|
||||
def __init__(self, *args):
|
||||
super(convcbitvec, self).__init__(*args)
|
||||
assert(args[2].n == args[0])
|
||||
args[1].set_size(args[0])
|
||||
|
||||
class convcbit2s(BinaryVectorInstruction):
|
||||
code = opcodes['CONVCBIT2S']
|
||||
arg_format = ['int','sbw','cb']
|
||||
|
||||
@base.vectorize
|
||||
class split(base.Instruction):
|
||||
code = opcodes['SPLIT']
|
||||
arg_format = tools.chain(['int','s'], tools.cycle(['sbw']))
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(split_class, self).__init__(*args, **kwargs)
|
||||
assert (len(args) - 2) % args[0] == 0
|
||||
|
||||
class movsb(base.Instruction):
|
||||
code = opcodes['MOVSB']
|
||||
arg_format = ['sbw','sb']
|
||||
@@ -196,9 +237,9 @@ class bitb(base.Instruction):
|
||||
code = opcodes['BITB']
|
||||
arg_format = ['sbw']
|
||||
|
||||
class reveal(base.Instruction):
|
||||
class reveal(BinaryVectorInstruction, base.VarArgsInstruction, base.Mergeable):
|
||||
code = opcodes['REVEAL']
|
||||
arg_format = ['int','cbw','sb']
|
||||
arg_format = tools.cycle(['int','cbw','sb'])
|
||||
|
||||
class inputb(base.DoNotEliminateInstruction, base.VarArgsInstruction):
|
||||
__slots__ = []
|
||||
|
||||
@@ -1,15 +1,17 @@
|
||||
from Compiler.types import MemValue, read_mem_value, regint, Array, cint
|
||||
from Compiler.types import _bitint, _number, _fix, _structure, _bit
|
||||
from Compiler.types import _bitint, _number, _fix, _structure, _bit, _vec, sint
|
||||
from Compiler.program import Tape, Program
|
||||
from Compiler.exceptions import *
|
||||
from Compiler import util, oram, floatingpoint, library
|
||||
from Compiler import instructions_base
|
||||
import Compiler.GC.instructions as inst
|
||||
import operator
|
||||
import math
|
||||
from functools import reduce
|
||||
|
||||
class bits(Tape.Register, _structure, _bit):
|
||||
n = 40
|
||||
size = 1
|
||||
unit = 64
|
||||
PreOp = staticmethod(floatingpoint.PreOpN)
|
||||
decomposed = None
|
||||
@staticmethod
|
||||
@@ -19,9 +21,7 @@ class bits(Tape.Register, _structure, _bit):
|
||||
[1 - x for x in l])]
|
||||
@classmethod
|
||||
def get_type(cls, length):
|
||||
if length is None:
|
||||
return cls
|
||||
elif length == 1:
|
||||
if length == 1:
|
||||
return cls.bit_type
|
||||
if length not in cls.types:
|
||||
class bitsn(cls):
|
||||
@@ -65,6 +65,11 @@ class bits(Tape.Register, _structure, _bit):
|
||||
return res + suffix
|
||||
else:
|
||||
return self.decomposed[:n] + suffix
|
||||
@staticmethod
|
||||
def bit_decompose_clear(a, n_bits):
|
||||
res = [cbits.get_type(a.size)() for i in range(n_bits)]
|
||||
cbits.conv_cint_vec(a, *res)
|
||||
return res
|
||||
@classmethod
|
||||
def malloc(cls, size):
|
||||
return Program.prog.malloc(size, cls)
|
||||
@@ -87,41 +92,61 @@ class bits(Tape.Register, _structure, _bit):
|
||||
def __init__(self, value=None, n=None, size=None):
|
||||
if size != 1 and size is not None:
|
||||
raise Exception('invalid size for bit type: %s' % size)
|
||||
Tape.Register.__init__(self, self.reg_type, Program.prog.curr_tape)
|
||||
self.set_length(n or self.n)
|
||||
self.n = n or self.n
|
||||
size = math.ceil(self.n / self.unit) if self.n != None else None
|
||||
Tape.Register.__init__(self, self.reg_type, Program.prog.curr_tape,
|
||||
size=size)
|
||||
if value is not None:
|
||||
self.load_other(value)
|
||||
def copy(self):
|
||||
return type(self)(n=instructions_base.get_global_vector_size())
|
||||
def set_length(self, n):
|
||||
if n > self.max_length:
|
||||
print(self.max_length)
|
||||
raise Exception('too long: %d' % n)
|
||||
self.n = n
|
||||
def set_size(self, size):
|
||||
pass
|
||||
def load_other(self, other):
|
||||
if isinstance(other, cint):
|
||||
size = other.size
|
||||
other = sum(x << i for i, x in enumerate(other))
|
||||
other = other.to_regint(size)
|
||||
if isinstance(other, int):
|
||||
assert(self.n == other.size)
|
||||
self.conv_regint_by_bit(self.n, self, other.to_regint(1))
|
||||
elif isinstance(other, int):
|
||||
self.set_length(self.n or util.int_len(other))
|
||||
self.load_int(other)
|
||||
elif isinstance(other, regint):
|
||||
assert(other.size == 1)
|
||||
self.conv_regint(self.n, self, other)
|
||||
assert(other.size == math.ceil(self.n / self.unit))
|
||||
for i, (x, y) in enumerate(zip(self, other)):
|
||||
self.conv_regint(min(self.unit, self.n - i * self.unit), x, y)
|
||||
elif isinstance(self, type(other)) or isinstance(other, type(self)):
|
||||
self.mov(self, other)
|
||||
assert(self.n == other.n)
|
||||
for i in range(math.ceil(self.n / self.unit)):
|
||||
self.mov(self[i], other[i])
|
||||
else:
|
||||
try:
|
||||
other = self.bit_compose(other.bit_decompose())
|
||||
self.mov(self, other)
|
||||
self.load_other(other)
|
||||
except:
|
||||
raise CompilerError('cannot convert from %s to %s' % \
|
||||
(type(other), type(self)))
|
||||
def long_one(self):
|
||||
return 2**self.n - 1
|
||||
return 2**self.n - 1 if self.n != None else None
|
||||
def __repr__(self):
|
||||
return '%s(%d/%d)' % \
|
||||
(super(bits, self).__repr__(), self.n, type(self).n)
|
||||
if self.n != None:
|
||||
suffix = '%d' % self.n
|
||||
if type(self).n != None and type(self).n != self.n:
|
||||
suffice += '/%d' % type(self).n
|
||||
else:
|
||||
suffix = 'undef'
|
||||
return '%s(%s)' % (super(bits, self).__repr__(), suffix)
|
||||
__str__ = __repr__
|
||||
def _new_by_number(self, i, size=1):
|
||||
assert(size == 1)
|
||||
n = min(self.unit, self.n - (i - self.i) * self.unit)
|
||||
res = self.get_type(n)()
|
||||
res.i = i
|
||||
res.program = self.program
|
||||
return res
|
||||
|
||||
class cbits(bits):
|
||||
max_length = 64
|
||||
@@ -131,6 +156,12 @@ class cbits(bits):
|
||||
store_inst = (None, inst.stmcb)
|
||||
bitdec = inst.bitdecc
|
||||
conv_regint = staticmethod(lambda n, x, y: inst.convcint(x, y))
|
||||
conv_cint_vec = inst.convcintvec
|
||||
@classmethod
|
||||
def conv_regint_by_bit(cls, n, res, other):
|
||||
assert n == res.n
|
||||
assert n == other.size
|
||||
cls.conv_cint_vec(cint(other, size=other.size), res)
|
||||
types = {}
|
||||
def load_int(self, value):
|
||||
self.load_other(regint(value))
|
||||
@@ -187,6 +218,13 @@ class cbits(bits):
|
||||
if self.n > 64:
|
||||
raise CompilerError('too many bits')
|
||||
inst.convcbit(dest, self)
|
||||
def to_regint_by_bit(self):
|
||||
if self.n != None:
|
||||
res = regint(size=self.n)
|
||||
else:
|
||||
res = regint()
|
||||
inst.convcbitvec(self.n, res, self)
|
||||
return res
|
||||
|
||||
class sbits(bits):
|
||||
max_length = 128
|
||||
@@ -199,6 +237,11 @@ class sbits(bits):
|
||||
bitdec = inst.bitdecs
|
||||
bitcom = inst.bitcoms
|
||||
conv_regint = inst.convsint
|
||||
@classmethod
|
||||
def conv_regint_by_bit(cls, n, res, other):
|
||||
tmp = cbits.get_type(n)()
|
||||
tmp.conv_regint_by_bit(n, tmp, other)
|
||||
res.load_other(tmp)
|
||||
mov = inst.movsb
|
||||
types = {}
|
||||
def __init__(self, *args, **kwargs):
|
||||
@@ -250,18 +293,26 @@ class sbits(bits):
|
||||
self.mov(self, lower + (upper << 64))
|
||||
else:
|
||||
raise NotImplementedError('more than 128 bits wanted')
|
||||
def load_other(self, other):
|
||||
if isinstance(other, cbits) and self.n == other.n:
|
||||
inst.convcbit2s(self.n, self, other)
|
||||
else:
|
||||
super(sbits, self).load_other(other)
|
||||
@read_mem_value
|
||||
def __add__(self, other):
|
||||
if isinstance(other, int):
|
||||
if isinstance(other, int) or other is None:
|
||||
return self.xor_int(other)
|
||||
else:
|
||||
if not isinstance(other, sbits):
|
||||
other = self.conv(other)
|
||||
n = min(self.n, other.n)
|
||||
if self.n is None or other.n is None:
|
||||
assert self.n == other.n
|
||||
n = None
|
||||
else:
|
||||
n = min(self.n, other.n)
|
||||
res = self.new(n=n)
|
||||
inst.xors(n, res, self, other)
|
||||
max_n = max(self.n, other.n)
|
||||
if max_n > n:
|
||||
if self.n != None and max(self.n, other.n) > n:
|
||||
if self.n > n:
|
||||
longer = self
|
||||
else:
|
||||
@@ -293,17 +344,13 @@ class sbits(bits):
|
||||
return res
|
||||
except AttributeError:
|
||||
return NotImplemented
|
||||
@read_mem_value
|
||||
def __rmul__(self, other):
|
||||
if isinstance(other, cbits):
|
||||
return other * self
|
||||
else:
|
||||
return self.mul_int(other)
|
||||
__rmul__ = __mul__
|
||||
@read_mem_value
|
||||
def __and__(self, other):
|
||||
if util.is_zero(other):
|
||||
return 0
|
||||
elif util.is_all_ones(other, self.n):
|
||||
elif util.is_all_ones(other, self.n) or \
|
||||
(other is None and self.n == None):
|
||||
return self
|
||||
res = self.new(n=self.n)
|
||||
if not isinstance(other, sbits):
|
||||
@@ -314,6 +361,7 @@ class sbits(bits):
|
||||
assert(self.n == other.n)
|
||||
inst.ands(self.n, res, self, other)
|
||||
return res
|
||||
__rand__ = __and__
|
||||
def xor_int(self, other):
|
||||
if other == 0:
|
||||
return self
|
||||
@@ -326,6 +374,7 @@ class sbits(bits):
|
||||
for x,y in zip(self_bits, other_bits)] \
|
||||
+ extra_bits)
|
||||
def mul_int(self, other):
|
||||
assert(util.is_constant(other))
|
||||
if other == 0:
|
||||
return 0
|
||||
elif other == 1:
|
||||
@@ -344,14 +393,19 @@ class sbits(bits):
|
||||
# res = type(self)(n=self.n)
|
||||
# inst.nots(res, self)
|
||||
# return res
|
||||
one = self.new(value=self.long_one(), n=self.n)
|
||||
if self.n == None or self.n > self.unit:
|
||||
one = self.get_type(self.n)()
|
||||
self.conv_regint_by_bit(self.n, one, regint(1, size=self.n))
|
||||
else:
|
||||
one = self.new(value=self.long_one(), n=self.n)
|
||||
return self + one
|
||||
def __neg__(self):
|
||||
return self
|
||||
def reveal(self):
|
||||
if self.n > self.clear_type.max_length:
|
||||
raise Exception('too long to reveal')
|
||||
res = self.clear_type(n=self.n)
|
||||
if self.n == None or \
|
||||
self.n > max(self.max_length, self.clear_type.max_length):
|
||||
assert(self.unit == self.clear_type.unit)
|
||||
res = self.clear_type.get_type(self.n)()
|
||||
inst.reveal(self.n, res, self)
|
||||
return res
|
||||
def equal(self, other, n=None):
|
||||
@@ -395,8 +449,11 @@ class sbits(bits):
|
||||
@staticmethod
|
||||
def bit_adder(*args, **kwargs):
|
||||
return sbitint.bit_adder(*args, **kwargs)
|
||||
@staticmethod
|
||||
def ripple_carry_adder(*args, **kwargs):
|
||||
return sbitint.ripple_carry_adder(*args, **kwargs)
|
||||
|
||||
class sbitvec(object):
|
||||
class sbitvec(_vec):
|
||||
@classmethod
|
||||
def get_type(cls, n):
|
||||
return cls
|
||||
@@ -414,8 +471,27 @@ class sbitvec(object):
|
||||
def from_matrix(cls, matrix):
|
||||
# any number of rows, limited number of columns
|
||||
return cls.combine(cls(row) for row in matrix)
|
||||
def __init__(self, elements=None):
|
||||
if elements is not None:
|
||||
def __init__(self, elements=None, length=None):
|
||||
if length:
|
||||
assert isinstance(elements, sint)
|
||||
if Program.prog.use_split():
|
||||
n = Program.prog.use_split()
|
||||
columns = [[sbits.get_type(elements.size)()
|
||||
for i in range(n)] for i in range(length)]
|
||||
inst.split(n, elements, *sum(columns, []))
|
||||
x = sbitint.wallace_tree_without_finish(columns, False)
|
||||
v = sbitint.carry_lookahead_adder(x[0], x[1], fewer_inv=True)
|
||||
else:
|
||||
assert Program.prog.options.ring
|
||||
l = int(Program.prog.options.ring)
|
||||
r, r_bits = sint.get_edabit(length, size=elements.size)
|
||||
c = ((elements - r) << (l - length)).reveal()
|
||||
c >>= l - length
|
||||
cb = [(c >> i) for i in range(length)]
|
||||
x = sbitintvec.from_vec(r_bits) + sbitintvec.from_vec(cb)
|
||||
v = x.v
|
||||
self.v = v[:length]
|
||||
elif elements is not None:
|
||||
self.v = sbits.trans(elements)
|
||||
def popcnt(self):
|
||||
res = sbitint.wallace_tree([[b] for b in self.v])
|
||||
@@ -426,8 +502,15 @@ class sbitvec(object):
|
||||
if stop is None:
|
||||
start, stop = stop, start
|
||||
return sbits.trans(self.v[start:stop])
|
||||
def coerce(self, other):
|
||||
if isinstance(other, cint):
|
||||
size = other.size
|
||||
return (other.get_vector(base, min(64, size - base)) \
|
||||
for base in range(0, size, 64))
|
||||
return other
|
||||
def __xor__(self, other):
|
||||
return self.from_vec(x ^ y for x, y in zip(self.v, other.v))
|
||||
other = self.coerce(other)
|
||||
return self.from_vec(x ^ y for x, y in zip(self.v, other))
|
||||
def __and__(self, other):
|
||||
return self.from_vec(x & y for x, y in zip(self.v, other.v))
|
||||
def if_else(self, x, y):
|
||||
@@ -453,6 +536,26 @@ class sbitvec(object):
|
||||
def bit_decompose(self):
|
||||
return self.v
|
||||
bit_compose = from_vec
|
||||
def reveal(self):
|
||||
assert len(self) == 1
|
||||
return self.v[0].reveal()
|
||||
def long_one(self):
|
||||
return [x.long_one() for x in self.v]
|
||||
def __rsub__(self, other):
|
||||
return self.from_vec(y - x for x, y in zip(self.v, other))
|
||||
def half_adder(self, other):
|
||||
other = self.coerce(other)
|
||||
res = zip(*(x.half_adder(y) for x, y in zip(self.v, other)))
|
||||
return (self.from_vec(x) for x in res)
|
||||
def __mul__(self, other):
|
||||
if isinstance(other, int):
|
||||
return self.from_vec(x * other for x in self.v)
|
||||
def __add__(self, other):
|
||||
return self.from_vec(x + y for x, y in zip(self.v, other))
|
||||
def bit_and(self, other):
|
||||
return self & other
|
||||
def bit_xor(self, other):
|
||||
return self ^ other
|
||||
|
||||
class bit(object):
|
||||
n = 1
|
||||
|
||||
@@ -16,12 +16,13 @@ from functools import reduce
|
||||
class StraightlineAllocator:
|
||||
"""Allocate variables in a straightline program using n registers.
|
||||
It is based on the precondition that every register is only defined once."""
|
||||
def __init__(self, n):
|
||||
def __init__(self, n, program):
|
||||
self.alloc = dict_by_id()
|
||||
self.usage = Compiler.program.RegType.create_dict(lambda: 0)
|
||||
self.defined = dict_by_id()
|
||||
self.dealloc = set_by_id()
|
||||
self.n = n
|
||||
self.program = program
|
||||
|
||||
def alloc_reg(self, reg, free):
|
||||
base = reg.vectorbase
|
||||
@@ -76,7 +77,8 @@ class StraightlineAllocator:
|
||||
# unused register
|
||||
self.alloc_reg(j, alloc_pool)
|
||||
unused_regs.append(j)
|
||||
if unused_regs and len(unused_regs) == len(list(i.get_def())):
|
||||
if unused_regs and len(unused_regs) == len(list(i.get_def())) and \
|
||||
self.program.verbose:
|
||||
# only report if all assigned registers are unused
|
||||
print("Register(s) %s never used, assigned by '%s' in %s" % \
|
||||
(unused_regs,i,format_trace(i.caller)))
|
||||
@@ -175,37 +177,8 @@ class Merger:
|
||||
except StopIteration:
|
||||
return mergecount, None
|
||||
|
||||
def expand_vector_args(inst):
|
||||
if inst.is_vec():
|
||||
for arg in inst.args:
|
||||
arg.create_vector_elements()
|
||||
res = sum(list(zip(*inst.args)), ())
|
||||
return list(res)
|
||||
else:
|
||||
return inst.args
|
||||
|
||||
for i in merges_iter:
|
||||
if isinstance(instructions[n], startinput_class):
|
||||
instructions[n].args[1] += instructions[i].args[1]
|
||||
elif isinstance(instructions[n], (stopinput, gstopinput)):
|
||||
if instructions[n].get_size() != instructions[i].get_size():
|
||||
raise NotImplemented()
|
||||
else:
|
||||
instructions[n].args += instructions[i].args[1:]
|
||||
else:
|
||||
if instructions[n].get_size() != instructions[i].get_size():
|
||||
# merge as non-vector instruction
|
||||
instructions[n].args = expand_vector_args(instructions[n]) + \
|
||||
expand_vector_args(instructions[i])
|
||||
if instructions[n].is_vec():
|
||||
instructions[n].size = 1
|
||||
else:
|
||||
instructions[n].args += instructions[i].args
|
||||
|
||||
# join arg_formats if not special iterators
|
||||
# if not isinstance(instructions[n].arg_format, (itertools.repeat, itertools.cycle)) and \
|
||||
# not isinstance(instructions[i].arg_format, (itertools.repeat, itertools.cycle)):
|
||||
# instructions[n].arg_format += instructions[i].arg_format
|
||||
instructions[n].merge(instructions[i])
|
||||
instructions[i] = None
|
||||
self.merge_nodes(n, i)
|
||||
mergecount += 1
|
||||
@@ -343,7 +316,7 @@ class Merger:
|
||||
merge = merges[i]
|
||||
t = type(self.instructions[merge[0]])
|
||||
self.counter[t] += len(merge)
|
||||
if len(merge) > 1000:
|
||||
if len(merge) > 10000:
|
||||
print('Merging %d %s in round %d/%d' % \
|
||||
(len(merge), t.__name__, i, len(merges)))
|
||||
self.do_merge(merge)
|
||||
@@ -504,7 +477,8 @@ class Merger:
|
||||
next_available_depth[type(instr), d] = depth
|
||||
|
||||
round_type[depth] = instr.merge_id()
|
||||
parallel_open[depth] += len(instr.args) * instr.get_size()
|
||||
if int(options.max_parallel_open) > 0:
|
||||
parallel_open[depth] += len(instr.args) * instr.get_size()
|
||||
depths[n] = depth
|
||||
|
||||
if isinstance(instr, ReadMemoryInstruction):
|
||||
@@ -557,8 +531,9 @@ class Merger:
|
||||
print("Processed dependency of %d/%d instructions at" % \
|
||||
(n, len(block.instructions)), time.asctime())
|
||||
|
||||
if len(open_nodes) > 1000:
|
||||
print("Program has %d %s instructions" % (len(open_nodes), merge_classes))
|
||||
if len(open_nodes) > 1000 and self.block.parent.program.verbose:
|
||||
print("Basic block has %d %s instructions" %
|
||||
(len(open_nodes), merge_classes))
|
||||
|
||||
def merge_nodes(self, i, j):
|
||||
""" Merge node j into i, removing node j """
|
||||
@@ -608,7 +583,7 @@ class Merger:
|
||||
eliminate(list(G.pred[i])[0])
|
||||
eliminate(i)
|
||||
count += 2
|
||||
if count > 0:
|
||||
if count > 0 and self.block.parent.program.verbose:
|
||||
print('Eliminated %d dead instructions, among which %d opens: %s' \
|
||||
% (count, open_count, dict(stats)))
|
||||
|
||||
|
||||
@@ -69,13 +69,31 @@ def divide_by_two(res, x, m=1):
|
||||
inv2m(tmp, m)
|
||||
mulc(res, x, tmp)
|
||||
|
||||
@instructions_base.cisc
|
||||
def LTZ(s, a, k, kappa):
|
||||
"""
|
||||
s = (a ?< 0)
|
||||
|
||||
k: bit length of a
|
||||
"""
|
||||
from .types import sint
|
||||
from .types import sint, _bitint
|
||||
from .GC.types import sbitvec
|
||||
if program.use_split():
|
||||
movs(s, sint.conv(sbitvec(a, k).v[-1]))
|
||||
return
|
||||
elif program.options.ring:
|
||||
from . import floatingpoint
|
||||
assert(int(program.options.ring) >= k)
|
||||
m = k - 1
|
||||
shift = int(program.options.ring) - k
|
||||
r_prime, r_bin = MaskingBitsInRing(k)
|
||||
tmp = a - r_prime
|
||||
c_prime = (tmp << shift).reveal() >> shift
|
||||
a = r_bin[0].bit_decompose_clear(c_prime, m)
|
||||
b = r_bin[:m]
|
||||
u = CarryOutRaw(a[::-1], b[::-1])
|
||||
movs(s, sint.conv(r_bin[m].bit_xor(c_prime >> m).bit_xor(u)))
|
||||
return
|
||||
t = sint()
|
||||
Trunc(t, a, k, k - 1, kappa, True)
|
||||
subsfi(s, t, 0)
|
||||
@@ -86,6 +104,7 @@ def LessThanZero(a, k, kappa):
|
||||
LTZ(res, a, k, kappa)
|
||||
return res
|
||||
|
||||
@instructions_base.cisc
|
||||
def Trunc(d, a, k, m, kappa, signed):
|
||||
"""
|
||||
d = a >> m
|
||||
@@ -120,7 +139,7 @@ def TruncRing(d, a, k, m, signed):
|
||||
movs(d, res)
|
||||
return res
|
||||
|
||||
def TruncZeroes(a, k, m, signed):
|
||||
def TruncZeros(a, k, m, signed):
|
||||
if program.options.ring:
|
||||
return TruncLeakyInRing(a, k, m, signed)
|
||||
else:
|
||||
@@ -139,9 +158,8 @@ def TruncLeakyInRing(a, k, m, signed):
|
||||
from .types import sint, intbitint, cint, cgf2n
|
||||
n_bits = k - m
|
||||
n_shift = int(program.options.ring) - n_bits
|
||||
if program.use_dabit and n_bits > 1:
|
||||
r, r_bits = zip(*(sint.get_dabit() for i in range(n_bits)))
|
||||
r = sint.bit_compose(r)
|
||||
if n_bits > 1:
|
||||
r, r_bits = MaskingBitsInRing(n_bits, True)
|
||||
else:
|
||||
r_bits = [sint.get_random_bit() for i in range(n_bits)]
|
||||
r = sint.bit_compose(r_bits)
|
||||
@@ -150,7 +168,7 @@ def TruncLeakyInRing(a, k, m, signed):
|
||||
shifted = ((a << (n_shift - m)) + (r << n_shift)).reveal()
|
||||
masked = shifted >> n_shift
|
||||
u = sint()
|
||||
BitLTL(u, masked, r_bits, 0)
|
||||
BitLTL(u, masked, r_bits[:n_bits], 0)
|
||||
res = (u << n_bits) + masked - r
|
||||
if signed:
|
||||
res -= (1 << (n_bits - 1))
|
||||
@@ -174,6 +192,7 @@ def TruncRoundNearest(a, k, m, kappa, signed=False):
|
||||
Trunc(res, a + (1 << (m - 1)), k + 1, m, kappa, signed)
|
||||
return res
|
||||
|
||||
@instructions_base.cisc
|
||||
def Mod2m(a_prime, a, k, m, kappa, signed):
|
||||
"""
|
||||
a_prime = a % 2^m
|
||||
@@ -199,16 +218,11 @@ def Mod2mRing(a_prime, a, k, m, signed):
|
||||
assert(int(program.options.ring) >= k)
|
||||
from Compiler.types import sint, intbitint, cint
|
||||
shift = int(program.options.ring) - m
|
||||
if program.use_dabit:
|
||||
r, r_bin = zip(*(sint.get_dabit() for i in range(m)))
|
||||
else:
|
||||
r = [sint.get_random_bit() for i in range(m)]
|
||||
r_bin = r
|
||||
r_prime = sint.bit_compose(r)
|
||||
r_prime, r_bin = MaskingBitsInRing(m, True)
|
||||
tmp = a + r_prime
|
||||
c_prime = (tmp << shift).reveal() >> shift
|
||||
u = sint()
|
||||
BitLTL(u, c_prime, r_bin, 0)
|
||||
BitLTL(u, c_prime, r_bin[:m], 0)
|
||||
res = (u << m) + c_prime - r_prime
|
||||
if a_prime is not None:
|
||||
movs(a_prime, res)
|
||||
@@ -247,19 +261,35 @@ def Mod2mField(a_prime, a, k, m, kappa, signed):
|
||||
adds(a_prime, t[5], t[4])
|
||||
return r_dprime, r_prime, c, c_prime, u, t, c2k1
|
||||
|
||||
def MaskingBitsInRing(m, strict=False):
|
||||
from Compiler.types import sint
|
||||
if program.use_edabit():
|
||||
return sint.get_edabit(m, strict)
|
||||
elif program.use_dabit:
|
||||
r, r_bin = zip(*(sint.get_dabit() for i in range(m)))
|
||||
else:
|
||||
r = [sint.get_random_bit() for i in range(m)]
|
||||
r_bin = r
|
||||
return sint.bit_compose(r), r_bin
|
||||
|
||||
def PRandM(r_dprime, r_prime, b, k, m, kappa, use_dabit=True):
|
||||
"""
|
||||
r_dprime = random secret integer in range [0, 2^(k + kappa - m) - 1]
|
||||
r_prime = random secret integer in range [0, 2^m - 1]
|
||||
b = array containing bits of r_prime
|
||||
"""
|
||||
program.curr_tape.require_bit_length(k + kappa)
|
||||
from .types import sint
|
||||
if program.use_edabit() and m > 1:
|
||||
movs(r_dprime, sint.get_edabit(k + kappa - m, True)[0])
|
||||
tmp, b[:] = sint.get_edabit(m, True)
|
||||
movs(r_prime, tmp)
|
||||
return
|
||||
t = [[program.curr_block.new_reg('s') for j in range(2)] for i in range(m)]
|
||||
t[0][1] = b[-1]
|
||||
PRandInt(r_dprime, k + kappa - m)
|
||||
# r_dprime is always multiplied by 2^m
|
||||
program.curr_tape.require_bit_length(k + kappa)
|
||||
if use_dabit and program.use_dabit and m > 1:
|
||||
from .types import sint
|
||||
r, b[:] = zip(*(sint.get_dabit() for i in range(m)))
|
||||
r = sint.bit_compose(r)
|
||||
movs(r_prime, r)
|
||||
@@ -389,17 +419,25 @@ def CarryOut(res, a, b, c=0, kappa=None):
|
||||
b: array of secret bits (same length as a)
|
||||
c: initial carry-in bit
|
||||
"""
|
||||
from .types import sint
|
||||
movs(res, sint.conv(CarryOutRaw(a, b, c)))
|
||||
|
||||
def CarryOutRaw(a, b, c=0):
|
||||
assert len(a) == len(b)
|
||||
k = len(a)
|
||||
from . import types
|
||||
d = [program.curr_block.new_reg('s') for i in range(k)]
|
||||
s = [program.curr_block.new_reg('s') for i in range(3)]
|
||||
for i in range(k):
|
||||
d[i] = list(b[i].half_adder(a[i]))
|
||||
s[0] = d[-1][0] * c
|
||||
s[0] = d[-1][0].bit_and(c)
|
||||
s[1] = d[-1][1] + s[0]
|
||||
d[-1][1] = s[1]
|
||||
|
||||
movs(res, types.sint.conv(CarryOutAux(d[::-1], kappa)))
|
||||
return CarryOutAux(d[::-1], None)
|
||||
|
||||
def CarryOutRawLE(a, b, c=0):
|
||||
""" Little-endian version """
|
||||
return CarryOutRaw(a[::-1], b[::-1], c)
|
||||
|
||||
def CarryOutLE(a, b, c=0):
|
||||
""" Little-endian version """
|
||||
@@ -416,13 +454,12 @@ def BitLTL(res, a, b, kappa):
|
||||
b: array of secret bits (same length as a)
|
||||
"""
|
||||
k = len(b)
|
||||
from . import floatingpoint
|
||||
a_bits = floatingpoint.bits(a, k)
|
||||
a_bits = b[0].bit_decompose_clear(a, k)
|
||||
s = [[program.curr_block.new_reg('s') for i in range(k)] for j in range(2)]
|
||||
t = [program.curr_block.new_reg('s') for i in range(1)]
|
||||
for i in range(len(b)):
|
||||
s[0][i] = b[0].long_one() - b[i]
|
||||
CarryOut(t[0], a_bits[::-1], s[0][::-1], 1, kappa)
|
||||
CarryOut(t[0], a_bits[::-1], s[0][::-1], b[0].long_one(), kappa)
|
||||
subsfi(res, t[0], 1)
|
||||
return a_bits, s[0]
|
||||
|
||||
|
||||
@@ -71,9 +71,10 @@ def run(args, options, param=-1, merge_opens=True, emulate=True, \
|
||||
|
||||
if prog.main_thread_running:
|
||||
prog.update_req(prog.curr_tape)
|
||||
print('Program requires:', repr(prog.req_num))
|
||||
print('Cost:', 0 if prog.req_num is None else prog.req_num.cost())
|
||||
print('Memory size:', dict(prog.allocated_mem))
|
||||
if prog.verbose:
|
||||
print('Program requires:', repr(prog.req_num))
|
||||
print('Cost:', 0 if prog.req_num is None else prog.req_num.cost())
|
||||
print('Memory size:', dict(prog.allocated_mem))
|
||||
|
||||
# finalize the memory
|
||||
prog.finalize_memory()
|
||||
|
||||
@@ -17,22 +17,16 @@ P_VALUES = { 32: 2147565569, \
|
||||
|
||||
P_VALUES[-1] = P_VALUES[128]
|
||||
|
||||
BIT_LENGTHS = { -1: 32,
|
||||
BIT_LENGTHS = { -1: 64,
|
||||
32: 16,
|
||||
64: 16,
|
||||
128: 64,
|
||||
256: 64,
|
||||
512: 64 }
|
||||
|
||||
STAT_SEC = { -1: 6,
|
||||
32: 6,
|
||||
64: 30,
|
||||
128: 40,
|
||||
256: 40,
|
||||
512: 40 }
|
||||
|
||||
|
||||
COST = { 'modp': defaultdict(lambda: 0,
|
||||
COST = defaultdict(lambda: defaultdict(lambda: 0),
|
||||
{ 'modp': defaultdict(lambda: 0,
|
||||
{ 'triple': 0.00020652622883106154,
|
||||
'square': 0.00020652622883106154,
|
||||
'bit': 0.00020652622883106154,
|
||||
@@ -51,7 +45,7 @@ COST = { 'modp': defaultdict(lambda: 0,
|
||||
'all': { 'round': 0,
|
||||
'inv': 0,
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
|
||||
try:
|
||||
|
||||
@@ -4,6 +4,7 @@ from . import types
|
||||
from . import comparison
|
||||
from . import program
|
||||
from . import util
|
||||
from . import instructions_base
|
||||
|
||||
##
|
||||
## Helper functions for floating point arithmetic
|
||||
@@ -50,13 +51,14 @@ def maskField(a, k, kappa):
|
||||
asm_open(c, a + two_power(k) * r_dprime + r_prime)# + 2**(k-1))
|
||||
return c, r
|
||||
|
||||
@instructions_base.ret_cisc
|
||||
def EQZ(a, k, kappa):
|
||||
if program.Program.prog.options.ring:
|
||||
c, r = maskRing(a, k)
|
||||
else:
|
||||
c, r = maskField(a, k, kappa)
|
||||
d = [None]*k
|
||||
for i,b in enumerate(bits(c, k)):
|
||||
for i,b in enumerate(r[0].bit_decompose_clear(c, k)):
|
||||
d[i] = r[i].bit_xor(b)
|
||||
return 1 - types.sint.conv(KOR(d, kappa))
|
||||
|
||||
@@ -299,6 +301,7 @@ def BitDec(a, k, m, kappa, bits_to_compute=None):
|
||||
|
||||
def BitDecRing(a, k, m):
|
||||
n_shift = int(program.Program.prog.options.ring) - m
|
||||
assert(n_shift >= 0)
|
||||
if program.Program.prog.use_dabit:
|
||||
r, r_bits = zip(*(types.sint.get_dabit() for i in range(m)))
|
||||
r = types.sint.bit_compose(r)
|
||||
@@ -328,10 +331,11 @@ def BitDecField(a, k, m, kappa, bits_to_compute=None):
|
||||
print('BitDec assertion failed')
|
||||
print('a =', a.value)
|
||||
print('a mod 2^%d =' % k, (a.value % 2**k))
|
||||
res = r[0].bit_adder(r, list(bits(c,m)))
|
||||
res = r[0].bit_adder(r, list(r[0].bit_decompose_clear(c,m)))
|
||||
return [types.sint.conv(bit) for bit in res]
|
||||
|
||||
|
||||
@instructions_base.ret_cisc
|
||||
def Pow2(a, l, kappa):
|
||||
m = int(ceil(log(l, 2)))
|
||||
t = BitDec(a, m, m, kappa)
|
||||
@@ -361,10 +365,16 @@ def B2U_from_Pow2(pow2a, l, kappa):
|
||||
for i in range(l):
|
||||
bit(r[i])
|
||||
r_bits = r
|
||||
comparison.PRandInt(t, kappa)
|
||||
asm_open(c, pow2a + two_power(l) * t + sum(two_power(i)*r[i] for i in range(l)))
|
||||
comparison.program.curr_tape.require_bit_length(l + kappa)
|
||||
c = list(bits(c, l))
|
||||
if program.Program.prog.options.ring:
|
||||
n_shift = int(program.Program.prog.options.ring) - l
|
||||
assert n_shift > 0
|
||||
c = ((pow2a + types.sint.bit_compose(r)) << n_shift).reveal() >> n_shift
|
||||
else:
|
||||
comparison.PRandInt(t, kappa)
|
||||
asm_open(c, pow2a + two_power(l) * t +
|
||||
sum(two_power(i) * r[i] for i in range(l)))
|
||||
comparison.program.curr_tape.require_bit_length(l + kappa)
|
||||
c = list(r_bits[0].bit_decompose_clear(c, l))
|
||||
x = [r_bits[i].bit_xor(c[i]) for i in range(l)]
|
||||
#print ' '.join(str(b.value) for b in x)
|
||||
y = PreOR(x, kappa)
|
||||
@@ -402,10 +412,14 @@ def Trunc(a, l, m, kappa, compute_modulo=False, signed=False):
|
||||
r_prime += t2
|
||||
r_dprime += t1 - t2
|
||||
#assert(r_prime.value == (sum(2**i*x[i].value*r[i].value for i in range(l)) % comparison.program.P))
|
||||
comparison.PRandInt(rk, kappa)
|
||||
r_dprime += two_power(l) * rk
|
||||
#assert(r_dprime.value == (2**l * rk.value + sum(2**i*(1 - x[i].value)*r[i].value for i in range(l)) % comparison.program.P))
|
||||
asm_open(c, a + r_dprime + r_prime)
|
||||
if program.Program.prog.options.ring:
|
||||
n_shift = int(program.Program.prog.options.ring) - l
|
||||
c = ((a + r_dprime + r_prime) << n_shift).reveal() >> n_shift
|
||||
else:
|
||||
comparison.PRandInt(rk, kappa)
|
||||
r_dprime += two_power(l) * rk
|
||||
#assert(r_dprime.value == (2**l * rk.value + sum(2**i*(1 - x[i].value)*r[i].value for i in range(l)) % comparison.program.P))
|
||||
asm_open(c, a + r_dprime + r_prime)
|
||||
for i in range(1,l):
|
||||
ci[i] = c % two_power(i)
|
||||
#assert(ci[i].value == c.value % 2**i)
|
||||
@@ -439,6 +453,12 @@ def TruncInRing(to_shift, l, pow2m):
|
||||
bits = types.intbitint.bit_adder(r_bits, masked.bit_decompose(l))
|
||||
return types.sint.bit_compose(reversed(bits))
|
||||
|
||||
def SplitInRing(a, l, m):
|
||||
pow2m = Pow2(m, l, None)
|
||||
upper = TruncInRing(a, l, pow2m)
|
||||
lower = a - upper * pow2m
|
||||
return lower, upper, pow2m
|
||||
|
||||
def TruncRoundNearestAdjustOverflow(a, length, target_length, kappa):
|
||||
t = comparison.TruncRoundNearest(a, length, length - target_length, kappa)
|
||||
overflow = t.greater_equal(two_power(target_length), target_length + 1, kappa)
|
||||
@@ -496,6 +516,7 @@ def FLRound(x, mode):
|
||||
p = ((p1 + d * a) * (1 - b) + b * away_from_zero * (1 - l)) * (1 - z)
|
||||
return v, p, z, s
|
||||
|
||||
@instructions_base.ret_cisc
|
||||
def TruncPr(a, k, m, kappa=None, signed=True):
|
||||
""" Probabilistic truncation [a/2^m + u]
|
||||
where Pr[u = 1] = (a % 2^m) / 2^m
|
||||
@@ -513,8 +534,11 @@ def TruncPrRing(a, k, m, signed=True):
|
||||
n_ring = int(program.Program.prog.options.ring)
|
||||
assert n_ring >= k, '%d too large' % k
|
||||
if k == n_ring:
|
||||
for i in range(m):
|
||||
a += types.sint.get_random_bit() << i
|
||||
if program.Program.prog.use_edabit():
|
||||
a += types.sint.get_edabit(m, True)[0]
|
||||
else:
|
||||
for i in range(m):
|
||||
a += types.sint.get_random_bit() << i
|
||||
return comparison.TruncLeakyInRing(a, k, m, signed=signed)
|
||||
else:
|
||||
from .types import sint
|
||||
@@ -525,13 +549,22 @@ def TruncPrRing(a, k, m, signed=True):
|
||||
trunc_pr(res, a, k, m)
|
||||
else:
|
||||
# extra bit to mask overflow
|
||||
r_bits = [sint.get_random_bit() for i in range(k + 1)]
|
||||
n_shift = n_ring - len(r_bits)
|
||||
tmp = a + sint.bit_compose(r_bits)
|
||||
if program.Program.prog.use_edabit():
|
||||
lower = sint.get_edabit(m, True)[0]
|
||||
upper = sint.get_edabit(k - m, True)[0]
|
||||
msb = sint.get_random_bit()
|
||||
r = (msb << k) + (upper << m) + lower
|
||||
else:
|
||||
r_bits = [sint.get_random_bit() for i in range(k + 1)]
|
||||
r = sint.bit_compose(r_bits)
|
||||
upper = sint.bit_compose(r_bits[m:k])
|
||||
msb = r_bits[-1]
|
||||
n_shift = n_ring - (k + 1)
|
||||
tmp = a + r
|
||||
masked = (tmp << n_shift).reveal()
|
||||
shifted = (masked << 1 >> (n_shift + m + 1))
|
||||
overflow = r_bits[-1].bit_xor(masked >> (n_ring - 1))
|
||||
res = shifted - sint.bit_compose(r_bits[m:k]) + \
|
||||
overflow = msb.bit_xor(masked >> (n_ring - 1))
|
||||
res = shifted - upper + \
|
||||
(overflow << (k - m))
|
||||
if signed:
|
||||
res -= (1 << (k - m - 1))
|
||||
@@ -555,6 +588,7 @@ def TruncPrField(a, k, m, kappa=None):
|
||||
d = (a - a_prime) / two_to_m
|
||||
return d
|
||||
|
||||
@instructions_base.ret_cisc
|
||||
def SDiv(a, b, l, kappa, round_nearest=False):
|
||||
theta = int(ceil(log(l / 3.5) / log(2)))
|
||||
alpha = two_power(2*l)
|
||||
@@ -564,7 +598,7 @@ def SDiv(a, b, l, kappa, round_nearest=False):
|
||||
y = y.round(2 * l + 1, l, kappa, round_nearest, signed=False)
|
||||
x2 = types.sint()
|
||||
comparison.Mod2m(x2, x, 2 * l + 1, l, kappa, False)
|
||||
x1 = comparison.TruncZeroes(x - x2, 2 * l + 1, l, True)
|
||||
x1 = comparison.TruncZeros(x - x2, 2 * l + 1, l, True)
|
||||
for i in range(theta-1):
|
||||
y = y * (x1 + two_power(l)) + (y * x2).round(2 * l, l, kappa,
|
||||
round_nearest,
|
||||
@@ -576,7 +610,7 @@ def SDiv(a, b, l, kappa, round_nearest=False):
|
||||
signed=False)
|
||||
x2 = types.sint()
|
||||
comparison.Mod2m(x2, x, 2 * l, l, kappa, False)
|
||||
x1 = comparison.TruncZeroes(x - x2, 2 * l + 1, l, True)
|
||||
x1 = comparison.TruncZeros(x - x2, 2 * l + 1, l, True)
|
||||
y = y * (x1 + two_power(l)) + (y * x2).round(2 * l, l, kappa,
|
||||
round_nearest, signed=False)
|
||||
y = y.round(2 * l + 1, l - 1, kappa, round_nearest)
|
||||
|
||||
@@ -11,8 +11,10 @@ documentation
|
||||
"""
|
||||
|
||||
import itertools
|
||||
import operator
|
||||
from . import tools
|
||||
from random import randint
|
||||
from functools import reduce
|
||||
from Compiler.config import *
|
||||
from Compiler.exceptions import *
|
||||
import Compiler.instructions_base as base
|
||||
@@ -318,6 +320,11 @@ class use_inp(base.Instruction):
|
||||
code = base.opcodes['USE_INP']
|
||||
arg_format = ['int','int','int']
|
||||
|
||||
class use_edabit(base.Instruction):
|
||||
r""" edaBit usage. """
|
||||
code = base.opcodes['USE_EDABIT']
|
||||
arg_format = ['int','int','int']
|
||||
|
||||
class run_tape(base.Instruction):
|
||||
r""" Start tape $n$ in thread $c_i$ with argument $c_j$. """
|
||||
code = base.opcodes['RUN_TAPE']
|
||||
@@ -808,7 +815,29 @@ class dabit(base.DataInstruction):
|
||||
code = base.opcodes['DABIT']
|
||||
arg_format = ['sw', 'sbw']
|
||||
field_type = 'modp'
|
||||
data_type = 'bit'
|
||||
data_type = 'dabit'
|
||||
|
||||
@base.vectorize
|
||||
class edabit(base.Instruction):
|
||||
""" edaBit """
|
||||
__slots__ = []
|
||||
code = base.opcodes['EDABIT']
|
||||
arg_format = tools.chain(['sw'], itertools.repeat('sbw'))
|
||||
field_type = 'modp'
|
||||
|
||||
def add_usage(self, req_node):
|
||||
req_node.increment(('edabit', len(self.args) - 1), self.get_size())
|
||||
|
||||
@base.vectorize
|
||||
class sedabit(base.Instruction):
|
||||
""" strict edaBit """
|
||||
__slots__ = []
|
||||
code = base.opcodes['SEDABIT']
|
||||
arg_format = tools.chain(['sw'], itertools.repeat('sbw'))
|
||||
field_type = 'modp'
|
||||
|
||||
def add_usage(self, req_node):
|
||||
req_node.increment(('sedabit', len(self.args) - 1), self.get_size())
|
||||
|
||||
@base.gf2n
|
||||
@base.vectorize
|
||||
@@ -988,7 +1017,19 @@ class startinput(base.RawInputInstruction):
|
||||
req_node.increment((self.field_type, 'input', self.args[0]), \
|
||||
self.args[1])
|
||||
|
||||
class stopinput(base.RawInputInstruction):
|
||||
def merge(self, other):
|
||||
self.args[1] += other.args[1]
|
||||
|
||||
class StopInputInstruction(base.RawInputInstruction):
|
||||
__slots__ = []
|
||||
|
||||
def merge(self, other):
|
||||
if self.get_size() != other.get_size():
|
||||
raise NotImplemented()
|
||||
else:
|
||||
self.args += other.args[1:]
|
||||
|
||||
class stopinput(StopInputInstruction):
|
||||
r""" Receive inputs from player $p$ and put in registers. """
|
||||
__slots__ = []
|
||||
code = base.opcodes['STOPINPUT']
|
||||
@@ -997,7 +1038,7 @@ class stopinput(base.RawInputInstruction):
|
||||
def has_var_args(self):
|
||||
return True
|
||||
|
||||
class gstopinput(base.RawInputInstruction):
|
||||
class gstopinput(StopInputInstruction):
|
||||
r""" Receive inputs from player $p$ and put in registers. """
|
||||
__slots__ = []
|
||||
code = 0x100 + base.opcodes['STOPINPUT']
|
||||
@@ -1322,6 +1363,26 @@ class bitdecint(base.Instruction):
|
||||
code = base.opcodes['BITDECINT']
|
||||
arg_format = tools.chain(['ci'], itertools.repeat('ciw'))
|
||||
|
||||
class incint(base.VectorInstruction):
|
||||
__slots__ = []
|
||||
code = base.opcodes['INCINT']
|
||||
arg_format = ['ciw', 'ci', 'i', 'i', 'i']
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
assert len(args[1]) == 1
|
||||
if len(args) == 3:
|
||||
args = list(args) + [1, len(args[0])]
|
||||
super(incint, self).__init__(*args, **kwargs)
|
||||
|
||||
class shuffle(base.VectorInstruction):
|
||||
__slots__ = []
|
||||
code = base.opcodes['SHUFFLE']
|
||||
arg_format = ['ciw','ci']
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(shuffle, self).__init__(*args, **kwargs)
|
||||
assert len(args[0]) == len(args[1])
|
||||
|
||||
###
|
||||
### Clear comparison instructions
|
||||
###
|
||||
@@ -1429,6 +1490,9 @@ class convmodp(base.Instruction):
|
||||
code = base.opcodes['CONVMODP']
|
||||
arg_format = ['ciw', 'c', 'int']
|
||||
def __init__(self, *args, **kwargs):
|
||||
if len(args) == len(self.arg_format):
|
||||
super(convmodp_class, self).__init__(*args)
|
||||
return
|
||||
bitlength = kwargs.get('bitlength')
|
||||
bitlength = program.bit_length if bitlength is None else bitlength
|
||||
if bitlength > 64:
|
||||
@@ -1472,7 +1536,7 @@ class muls(base.VarArgsInstruction, base.DataInstruction):
|
||||
def merge_id(self):
|
||||
# can merge different sizes
|
||||
# but not if large
|
||||
if self.get_size() > 100:
|
||||
if self.get_size() is None or self.get_size() > 100:
|
||||
return type(self), self.get_size()
|
||||
return type(self)
|
||||
|
||||
@@ -1561,6 +1625,31 @@ class dotprods(base.VarArgsInstruction, base.DataInstruction):
|
||||
for reg in self.args[i + 2:i + self.args[i]]:
|
||||
yield reg
|
||||
|
||||
class matmul_base(base.DataInstruction):
|
||||
data_type = 'triple'
|
||||
is_vec = lambda self: True
|
||||
|
||||
def get_repeat(self):
|
||||
return reduce(operator.mul, self.args[3:6])
|
||||
|
||||
class matmuls(matmul_base):
|
||||
""" Secret matrix multiplication """
|
||||
code = base.opcodes['MATMULS']
|
||||
arg_format = ['sw','s','s','int','int','int']
|
||||
|
||||
class matmulsm(matmul_base):
|
||||
""" Secret matrix multiplication reading directly from memory """
|
||||
code = base.opcodes['MATMULSM']
|
||||
arg_format = ['sw','ci','ci','int','int','int','ci','ci','ci','ci',
|
||||
'int','int']
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
matmul_base.__init__(self, *args, **kwargs)
|
||||
for i in range(2):
|
||||
assert args[6 + i].size == args[3 + i]
|
||||
for i in range(2):
|
||||
assert args[8 + i].size == args[4 + i]
|
||||
|
||||
@base.vectorize
|
||||
class trunc_pr(base.VarArgsInstruction):
|
||||
""" Probalistic truncation for semi-honest computation """
|
||||
|
||||
@@ -8,6 +8,7 @@ from Compiler.exceptions import *
|
||||
from Compiler.config import *
|
||||
from Compiler import util
|
||||
from Compiler import tools
|
||||
from Compiler import program
|
||||
|
||||
|
||||
###
|
||||
@@ -59,6 +60,7 @@ opcodes = dict(
|
||||
NPLAYERS = 0xE2,
|
||||
THRESHOLD = 0xE3,
|
||||
PLAYERID = 0xE4,
|
||||
USE_EDABIT = 0xE5,
|
||||
# Addition
|
||||
ADDC = 0x20,
|
||||
ADDS = 0x21,
|
||||
@@ -93,6 +95,8 @@ opcodes = dict(
|
||||
MULRS = 0xA7,
|
||||
DOTPRODS = 0xA8,
|
||||
TRUNC_PR = 0xA9,
|
||||
MATMULS = 0xAA,
|
||||
MATMULSM = 0xAB,
|
||||
# Data access
|
||||
TRIPLE = 0x50,
|
||||
BIT = 0x51,
|
||||
@@ -103,6 +107,8 @@ opcodes = dict(
|
||||
INPUTMASK = 0x56,
|
||||
PREP = 0x57,
|
||||
DABIT = 0x58,
|
||||
EDABIT = 0x59,
|
||||
SEDABIT = 0x5A,
|
||||
# Input
|
||||
INPUT = 0x60,
|
||||
INPUTFIX = 0xF0,
|
||||
@@ -153,6 +159,8 @@ opcodes = dict(
|
||||
MULINT = 0x9D,
|
||||
DIVINT = 0x9E,
|
||||
PRINTINT = 0x9F,
|
||||
INCINT = 0xD1,
|
||||
SHUFFLE = 0xD2,
|
||||
# Conversion
|
||||
CONVINT = 0xC0,
|
||||
CONVMODP = 0xC1,
|
||||
@@ -235,13 +243,17 @@ def vectorize(instruction, global_dict=None):
|
||||
def __init__(self, size, *args, **kwargs):
|
||||
self.size = size
|
||||
super(Vectorized_Instruction, self).__init__(*args, **kwargs)
|
||||
for arg,f in zip(self.args, self.arg_format):
|
||||
if issubclass(ArgFormats[f], RegisterArgFormat):
|
||||
arg.set_size(size)
|
||||
if not kwargs.get('copying', False):
|
||||
for arg,f in zip(self.args, self.arg_format):
|
||||
if issubclass(ArgFormats[f], RegisterArgFormat):
|
||||
arg.set_size(size)
|
||||
def get_code(self):
|
||||
return (self.size << 10) + self.code
|
||||
return instruction.get_code(self, self.get_size())
|
||||
def get_pre_arg(self):
|
||||
return "%d, " % self.size
|
||||
try:
|
||||
return "%d, " % self.size
|
||||
except:
|
||||
return "{undef}, "
|
||||
def is_vec(self):
|
||||
return True
|
||||
def get_size(self):
|
||||
@@ -250,6 +262,9 @@ def vectorize(instruction, global_dict=None):
|
||||
set_global_vector_size(self.size)
|
||||
super(Vectorized_Instruction, self).expand()
|
||||
reset_global_vector_size()
|
||||
def copy(self, size, subs):
|
||||
return type(self)(size, *self.get_new_args(size, subs),
|
||||
copying=True)
|
||||
|
||||
@functools.wraps(instruction)
|
||||
def maybe_vectorized_instruction(*args, **kwargs):
|
||||
@@ -360,6 +375,169 @@ def gf2n(instruction):
|
||||
return maybe_gf2n_instruction
|
||||
#return instruction
|
||||
|
||||
class Mergeable:
|
||||
pass
|
||||
|
||||
def cisc(function):
|
||||
class MergeCISC(Mergeable):
|
||||
instructions = {}
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
self.calls = [(args, kwargs)]
|
||||
self.params = []
|
||||
self.used = []
|
||||
for arg in self.args[1:]:
|
||||
if isinstance(arg, program.curr_tape.Register):
|
||||
self.used.append(arg)
|
||||
self.params.append(type(arg))
|
||||
else:
|
||||
self.params.append(arg)
|
||||
self.function = function
|
||||
program.curr_block.instructions.append(self)
|
||||
|
||||
def get_def(self):
|
||||
return [self.args[0]]
|
||||
|
||||
def get_used(self):
|
||||
return self.used
|
||||
|
||||
def is_vec(self):
|
||||
return True
|
||||
|
||||
def merge_id(self):
|
||||
return self.function, tuple(self.params), \
|
||||
tuple(sorted(self.kwargs.items()))
|
||||
|
||||
def merge(self, other):
|
||||
self.calls += other.calls
|
||||
|
||||
def get_size(self):
|
||||
return self.args[0].size
|
||||
|
||||
def new_instructions(self, size, regs):
|
||||
if self.merge_id() not in self.instructions:
|
||||
from Compiler.program import Tape
|
||||
tape = Tape(self.function.__name__, program)
|
||||
old_tape = program.curr_tape
|
||||
program.curr_tape = tape
|
||||
block = tape.BasicBlock(tape, None, None)
|
||||
tape.active_basicblock = block
|
||||
set_global_vector_size(None)
|
||||
args = []
|
||||
for arg in self.args:
|
||||
try:
|
||||
args.append(type(arg)(size=None))
|
||||
except:
|
||||
args.append(arg)
|
||||
program.options.cisc = False
|
||||
self.function(*args, **self.kwargs)
|
||||
program.options.cisc = True
|
||||
reset_global_vector_size()
|
||||
program.curr_tape = old_tape
|
||||
from Compiler.allocator import Merger
|
||||
merger = Merger(block, program.options,
|
||||
tuple(program.to_merge))
|
||||
args[0].can_eliminate = False
|
||||
merger.eliminate_dead_code()
|
||||
assert int(program.options.max_parallel_open) == 0, \
|
||||
'merging restriction not compatible with ' \
|
||||
'mergeable CISC instructions'
|
||||
merger.longest_paths_merge()
|
||||
filtered = filter(lambda x: x is not None, block.instructions)
|
||||
self.instructions[self.merge_id()] = list(filtered), args
|
||||
template, args = self.instructions[self.merge_id()]
|
||||
subs = util.dict_by_id()
|
||||
for arg, reg in zip(args, regs):
|
||||
subs[arg] = reg
|
||||
set_global_vector_size(size)
|
||||
for inst in template:
|
||||
inst.copy(size, subs)
|
||||
reset_global_vector_size()
|
||||
|
||||
def expand_merged(self):
|
||||
tape = program.curr_tape
|
||||
block = tape.BasicBlock(tape, None, None)
|
||||
tape.active_basicblock = block
|
||||
size = sum(call[0][0].size for call in self.calls)
|
||||
new_regs = []
|
||||
for arg in self.args:
|
||||
try:
|
||||
new_regs.append(type(arg)(size=size))
|
||||
except:
|
||||
break
|
||||
base = 0
|
||||
for call in self.calls:
|
||||
for new_reg, reg in zip(new_regs[1:], call[0][1:]):
|
||||
set_global_vector_size(reg.size)
|
||||
reg.mov(new_reg.get_vector(base, reg.size), reg)
|
||||
reset_global_vector_size()
|
||||
base += reg.size
|
||||
self.new_instructions(size, new_regs)
|
||||
base = 0
|
||||
for call in self.calls:
|
||||
reg = call[0][0]
|
||||
set_global_vector_size(reg.size)
|
||||
reg.mov(reg, new_regs[0].get_vector(base, reg.size))
|
||||
reset_global_vector_size()
|
||||
base += reg.size
|
||||
return block.instructions
|
||||
|
||||
MergeCISC.__name__ = function.__name__
|
||||
|
||||
def wrapper(*args, **kwargs):
|
||||
if program.options.cisc:
|
||||
return MergeCISC(*args, **kwargs)
|
||||
else:
|
||||
return function(*args, **kwargs)
|
||||
return wrapper
|
||||
|
||||
def ret_cisc(function):
|
||||
def instruction(res, *args, **kwargs):
|
||||
res.mov(res, function(*args, **kwargs))
|
||||
instruction.__name__ = function.__name__
|
||||
instruction = cisc(instruction)
|
||||
|
||||
def wrapper(*args, **kwargs):
|
||||
if not program.options.cisc:
|
||||
return function(*args, **kwargs)
|
||||
from Compiler import types
|
||||
if isinstance(args[0], types._clear):
|
||||
res_type = type(args[1])
|
||||
else:
|
||||
res_type = type(args[0])
|
||||
res = res_type(size=args[0].size)
|
||||
instruction(res, *args, **kwargs)
|
||||
return res
|
||||
return wrapper
|
||||
|
||||
def sfix_cisc(function):
|
||||
from Compiler.types import sfix, sint, cfix, copy_doc
|
||||
def instruction(res, arg, k, f):
|
||||
assert k is not None
|
||||
assert f is not None
|
||||
old = sfix.k, sfix.f, cfix.k, cfix.f
|
||||
sfix.k, sfix.f, cfix.k, cfix.f = [None] * 4
|
||||
res.mov(res, function(sfix._new(arg, k=k, f=f)).v)
|
||||
sfix.k, sfix.f, cfix.k, cfix.f = old
|
||||
instruction.__name__ = function.__name__
|
||||
instruction = cisc(instruction)
|
||||
|
||||
def wrapper(*args, **kwargs):
|
||||
if isinstance(args[0], sfix):
|
||||
assert len(args) == 1
|
||||
assert not kwargs
|
||||
assert args[0].size == args[0].v.size
|
||||
k = args[0].k
|
||||
f = args[0].f
|
||||
res = sfix._new(sint(size=args[0].size), k=k, f=f)
|
||||
instruction(res.v, args[0].v, k, f)
|
||||
return res
|
||||
else:
|
||||
return function(*args, **kwargs)
|
||||
copy_doc(wrapper, function)
|
||||
return wrapper
|
||||
|
||||
class RegType(object):
|
||||
""" enum-like static class for Register types """
|
||||
@@ -381,6 +559,8 @@ class RegType(object):
|
||||
return res
|
||||
|
||||
class ArgFormat(object):
|
||||
is_reg = False
|
||||
|
||||
@classmethod
|
||||
def check(cls, arg):
|
||||
return NotImplemented
|
||||
@@ -390,11 +570,13 @@ class ArgFormat(object):
|
||||
return NotImplemented
|
||||
|
||||
class RegisterArgFormat(ArgFormat):
|
||||
is_reg = True
|
||||
|
||||
@classmethod
|
||||
def check(cls, arg):
|
||||
if not isinstance(arg, program.curr_tape.Register):
|
||||
raise ArgumentError(arg, 'Invalid register argument')
|
||||
if arg.i > REG_MAX:
|
||||
if arg.i > REG_MAX and arg.i != float('inf'):
|
||||
raise ArgumentError(arg, 'Register index too large')
|
||||
if arg.program != program.curr_tape:
|
||||
raise ArgumentError(arg, 'Register from other tape, trace: %s' % \
|
||||
@@ -425,7 +607,7 @@ class ClearIntAF(RegisterArgFormat):
|
||||
class IntArgFormat(ArgFormat):
|
||||
@classmethod
|
||||
def check(cls, arg):
|
||||
if not isinstance(arg, int):
|
||||
if not isinstance(arg, int) and not arg is None:
|
||||
raise ArgumentError(arg, 'Expected an integer-valued argument')
|
||||
|
||||
@classmethod
|
||||
@@ -487,7 +669,7 @@ ArgFormats = {
|
||||
}
|
||||
|
||||
def format_str_is_reg(format_str):
|
||||
return issubclass(ArgFormats[format_str], RegisterArgFormat)
|
||||
return ArgFormats[format_str].is_reg
|
||||
|
||||
def format_str_is_writeable(format_str):
|
||||
return format_str_is_reg(format_str) and format_str[-1] == 'w'
|
||||
@@ -504,7 +686,8 @@ class Instruction(object):
|
||||
def __init__(self, *args, **kwargs):
|
||||
""" Create an instruction and append it to the program list. """
|
||||
self.args = list(args)
|
||||
self.check_args()
|
||||
if not kwargs.get('copying', False):
|
||||
self.check_args()
|
||||
if not program.FIRST_PASS:
|
||||
if kwargs.get('add_to_prog', True):
|
||||
program.curr_block.instructions.append(self)
|
||||
@@ -519,9 +702,9 @@ class Instruction(object):
|
||||
if Instruction.count % 100000 == 0:
|
||||
print("Compiled %d lines at" % self.__class__.count, time.asctime())
|
||||
|
||||
def get_code(self):
|
||||
return self.code
|
||||
|
||||
def get_code(self, prefix=0):
|
||||
return (prefix << 10) + self.code
|
||||
|
||||
def get_encoding(self):
|
||||
enc = int_to_bytes(self.get_code())
|
||||
# add the number of registers if instruction flagged as has var args
|
||||
@@ -540,12 +723,12 @@ class Instruction(object):
|
||||
|
||||
def check_args(self):
|
||||
""" Check the args match up with that specified in arg_format """
|
||||
for n,(arg,f) in enumerate(itertools.zip_longest(self.args, self.arg_format)):
|
||||
if arg is None:
|
||||
if not isinstance(self.arg_format, (list, tuple)):
|
||||
break # end of optional arguments
|
||||
else:
|
||||
raise CompilerError('Incorrect number of arguments for instruction %s' % (self))
|
||||
try:
|
||||
if len(self.args) != len(self.arg_format):
|
||||
raise CompilerError('Incorrect number of arguments for instruction %s' % (self))
|
||||
except TypeError:
|
||||
pass
|
||||
for n,(arg,f) in enumerate(zip(self.args, self.arg_format)):
|
||||
try:
|
||||
ArgFormats[f].check(arg)
|
||||
except ArgumentError as e:
|
||||
@@ -589,6 +772,42 @@ class Instruction(object):
|
||||
def merge_id(self):
|
||||
return type(self), self.get_size()
|
||||
|
||||
def merge(self, other):
|
||||
if self.get_size() != other.get_size():
|
||||
# merge as non-vector instruction
|
||||
self.args = self.expand_vector_args() + other.expand_vector_args()
|
||||
if self.is_vec():
|
||||
self.size = 1
|
||||
else:
|
||||
self.args += other.args
|
||||
|
||||
def expand_vector_args(self):
|
||||
if self.is_vec():
|
||||
for arg in self.args:
|
||||
arg.create_vector_elements()
|
||||
res = sum(list(zip(*self.args)), ())
|
||||
return list(res)
|
||||
else:
|
||||
return self.args
|
||||
|
||||
def expand_merged(self):
|
||||
return [self]
|
||||
|
||||
def get_new_args(self, size, subs):
|
||||
new_args = []
|
||||
for arg, f in zip(self.args, self.arg_format):
|
||||
if arg in subs:
|
||||
new_args.append(subs[arg])
|
||||
elif arg is None:
|
||||
new_args.append(size)
|
||||
else:
|
||||
if format_str_is_writeable(f):
|
||||
new_args.append(arg.copy())
|
||||
subs[arg] = new_args[-1]
|
||||
else:
|
||||
new_args.append(arg)
|
||||
return new_args
|
||||
|
||||
# String version of instruction attempting to replicate encoded version
|
||||
def __str__(self):
|
||||
|
||||
@@ -606,6 +825,13 @@ class VarArgsInstruction(Instruction):
|
||||
def has_var_args(self):
|
||||
return True
|
||||
|
||||
class VectorInstruction(Instruction):
|
||||
__slots__ = []
|
||||
is_vec = lambda self: True
|
||||
|
||||
def get_code(self):
|
||||
return super(VectorInstruction, self).get_code(len(self.args[0]))
|
||||
|
||||
###
|
||||
### Basic arithmetic
|
||||
###
|
||||
|
||||
@@ -22,7 +22,7 @@ def get_block():
|
||||
|
||||
def vectorize(function):
|
||||
def vectorized_function(*args, **kwargs):
|
||||
if len(args) > 0 and isinstance(args[0], program.Tape.Register):
|
||||
if len(args) > 0 and 'size' in dir(args[0]):
|
||||
instructions_base.set_global_vector_size(args[0].size)
|
||||
res = function(*args, **kwargs)
|
||||
instructions_base.reset_global_vector_size()
|
||||
@@ -88,7 +88,7 @@ def print_str(s, *args):
|
||||
raise CompilerError('Cannot print secret value:', args[i])
|
||||
elif isinstance(val, cfloat):
|
||||
val.print_float_plain()
|
||||
elif isinstance(val, list):
|
||||
elif isinstance(val, (list, tuple, Array)):
|
||||
print_str('[' + ', '.join('%s' for i in range(len(val))) + ']', *val)
|
||||
else:
|
||||
try:
|
||||
@@ -362,9 +362,14 @@ class Function:
|
||||
|
||||
class FunctionTape(Function):
|
||||
# not thread-safe
|
||||
def __init__(self, function, name=None, compile_args=[],
|
||||
single_thread=False):
|
||||
Function.__init__(self, function, name, compile_args)
|
||||
self.single_thread = single_thread
|
||||
def on_first_call(self, wrapped_function):
|
||||
self.thread = MPCThread(wrapped_function, self.name,
|
||||
args=self.compile_args)
|
||||
args=self.compile_args,
|
||||
single_thread=self.single_thread)
|
||||
def on_call(self, base, bases):
|
||||
return FunctionTapeCall(self.thread, base, bases)
|
||||
|
||||
@@ -376,6 +381,9 @@ def function_tape_with_compile_args(*args):
|
||||
return FunctionTape(function, compile_args=args)
|
||||
return wrapper
|
||||
|
||||
def single_thread_function_tape(function):
|
||||
return FunctionTape(function, single_thread=True)
|
||||
|
||||
def memorize(x):
|
||||
if isinstance(x, (tuple, list)):
|
||||
return tuple(memorize(i) for i in x)
|
||||
@@ -397,13 +405,15 @@ class FunctionBlock(Function):
|
||||
block.alloc_pool = defaultdict(set)
|
||||
del parent_node.children[-1]
|
||||
self.node = get_tape().req_node
|
||||
print('Compiling function', self.name)
|
||||
if get_program().verbose:
|
||||
print('Compiling function', self.name)
|
||||
result = wrapped_function(*self.compile_args)
|
||||
if result is not None:
|
||||
self.result = memorize(result)
|
||||
else:
|
||||
self.result = None
|
||||
print('Done compiling function', self.name)
|
||||
if get_program().verbose:
|
||||
print('Done compiling function', self.name)
|
||||
p_return_address = get_tape().program.malloc(1, 'ci')
|
||||
get_tape().function_basicblocks[block] = p_return_address
|
||||
return_address = regint.load_mem(p_return_address)
|
||||
@@ -528,7 +538,7 @@ def chunky_odd_even_merge_sort(a):
|
||||
a[m], a[m+step] = cond_swap(a[m], a[m+step])
|
||||
for i in range(len(a)):
|
||||
a[i].store_in_mem(i * a[i].sizeof())
|
||||
chunk = MPCThread(round, 'sort-%d-%d' % (l,k))
|
||||
chunk = MPCThread(round, 'sort-%d-%d' % (l,k), single_thread=True)
|
||||
chunk.start()
|
||||
chunk.join()
|
||||
#round()
|
||||
@@ -541,7 +551,6 @@ def chunkier_odd_even_merge_sort(a, n=None, max_chunk_size=512, n_threads=7, use
|
||||
a_base = instructions.program.malloc(n, 's')
|
||||
for i,j in enumerate(a):
|
||||
store_in_mem(j, a_base + i)
|
||||
instructions.program.restart_main_thread()
|
||||
else:
|
||||
a_base = a
|
||||
tmp_base = instructions.program.malloc(n, 's')
|
||||
@@ -657,7 +666,6 @@ def chunkier_odd_even_merge_sort(a, n=None, max_chunk_size=512, n_threads=7, use
|
||||
run_postproc()
|
||||
|
||||
if isinstance(a, list):
|
||||
instructions.program.restart_main_thread()
|
||||
for i in range(n):
|
||||
a[i] = load_secret_mem(a_base + i)
|
||||
instructions.program.free(a_base, 's')
|
||||
@@ -669,7 +677,6 @@ def loopy_chunkier_odd_even_merge_sort(a, n=None, max_chunk_size=512, n_threads=
|
||||
a_base = instructions.program.malloc(n, 's')
|
||||
for i,j in enumerate(a):
|
||||
store_in_mem(j, a_base + i)
|
||||
instructions.program.restart_main_thread()
|
||||
else:
|
||||
a_base = a
|
||||
tmp_base = instructions.program.malloc(n, 's')
|
||||
@@ -764,7 +771,6 @@ def loopy_chunkier_odd_even_merge_sort(a, n=None, max_chunk_size=512, n_threads=
|
||||
range_loop(outer, n // l)
|
||||
|
||||
if isinstance(a, list):
|
||||
instructions.program.restart_main_thread()
|
||||
for i in range(n):
|
||||
a[i] = load_secret_mem(a_base + i)
|
||||
instructions.program.free(a_base, 's')
|
||||
@@ -772,30 +778,39 @@ def loopy_chunkier_odd_even_merge_sort(a, n=None, max_chunk_size=512, n_threads=
|
||||
instructions.program.free(tmp_i, 'ci')
|
||||
|
||||
|
||||
def loopy_odd_even_merge_sort(a, sorted_length=1, n_parallel=32):
|
||||
def loopy_odd_even_merge_sort(a, sorted_length=1, n_parallel=32,
|
||||
n_threads=None):
|
||||
steps = {}
|
||||
l = sorted_length
|
||||
while l < len(a):
|
||||
l *= 2
|
||||
k = 1
|
||||
while k < l:
|
||||
k *= 2
|
||||
n_outer = len(a) // l
|
||||
n_inner = l // k
|
||||
n_innermost = 1 if k == 2 else k // 2 - 1
|
||||
@for_range_parallel(n_parallel // n_innermost // n_inner, n_outer)
|
||||
def loop(i):
|
||||
@for_range_parallel(n_parallel // n_innermost, n_inner)
|
||||
def inner(j):
|
||||
base = i*l + j
|
||||
step = l//k
|
||||
if k == 2:
|
||||
a[base], a[base+step] = cond_swap(a[base], a[base+step])
|
||||
else:
|
||||
@for_range_parallel(n_parallel, n_innermost)
|
||||
def f(i):
|
||||
m1 = step + i * 2 * step
|
||||
m2 = m1 + base
|
||||
a[m2], a[m2+step] = cond_swap(a[m2], a[m2+step])
|
||||
key = k
|
||||
if key not in steps:
|
||||
@function_block
|
||||
def step(l):
|
||||
l = MemValue(l)
|
||||
@for_range_opt_multithread(n_threads, len(a) // k)
|
||||
def _(i):
|
||||
n_inner = l // k
|
||||
j = i % n_inner
|
||||
i //= n_inner
|
||||
base = i*l + j
|
||||
step = l//k
|
||||
if k == 2:
|
||||
a[base], a[base+step] = \
|
||||
cond_swap(a[base], a[base+step])
|
||||
else:
|
||||
@for_range_opt(n_innermost)
|
||||
def f(i):
|
||||
m1 = step + i * 2 * step
|
||||
m2 = m1 + base
|
||||
a[m2], a[m2+step] = cond_swap(a[m2], a[m2+step])
|
||||
steps[key] = step
|
||||
steps[key](l)
|
||||
|
||||
def mergesort(A):
|
||||
B = Array(len(A), sint)
|
||||
@@ -899,11 +914,12 @@ def for_range_parallel(n_parallel, n_loops):
|
||||
def for_range_opt(n_loops, budget=None):
|
||||
""" Execute loop bodies in parallel up to an optimization budget.
|
||||
This prevents excessive loop unrolling. The budget is respected
|
||||
even with nested loops.
|
||||
even with nested loops. Note that optimization is rather
|
||||
rudimentary for runtime :py:obj:`n_loops` (regint/cint). Consider
|
||||
using :py:func:`for_range_parallel` in this case.
|
||||
|
||||
:param n_loops: compile-time (int)
|
||||
:param n_loops: int/regint/cint
|
||||
:param budget: number of instructions after which to start optimization (default is 100,000)
|
||||
:type: compile-time (int)
|
||||
|
||||
Example:
|
||||
|
||||
@@ -912,6 +928,7 @@ def for_range_opt(n_loops, budget=None):
|
||||
@for_range_opt(n)
|
||||
def _(i):
|
||||
...
|
||||
|
||||
"""
|
||||
return map_reduce_single(None, n_loops, budget=budget)
|
||||
|
||||
@@ -929,6 +946,8 @@ def map_reduce_single(n_parallel, n_loops, initializer=lambda *x: [],
|
||||
else:
|
||||
# use Arrays for multithread version
|
||||
use_array = True
|
||||
if not util.is_constant(n_loops):
|
||||
budget //= 10
|
||||
def decorator(loop_body):
|
||||
my_n_parallel = n_parallel
|
||||
if isinstance(n_parallel, int):
|
||||
@@ -955,17 +974,18 @@ def map_reduce_single(n_parallel, n_loops, initializer=lambda *x: [],
|
||||
r = reducer(mem_state, state)
|
||||
write_state_to_memory(r)
|
||||
else:
|
||||
if n_loops == 0:
|
||||
if is_zero(n_loops):
|
||||
return
|
||||
regint.push(0)
|
||||
n_opt_loops_reg = regint(0)
|
||||
n_opt_loops_inst = get_block().instructions[-1]
|
||||
parent_block = get_block()
|
||||
@while_do(lambda x: x + regint.pop() <= n_loops, regint(0))
|
||||
@while_do(lambda x: x + n_opt_loops_reg <= n_loops, regint(0))
|
||||
def _(i):
|
||||
state = tuplify(initializer())
|
||||
k = 0
|
||||
block = get_block()
|
||||
while k < n_loops and (len(get_block()) < budget \
|
||||
or k == 0) \
|
||||
while (not util.is_constant(n_loops) or k < n_loops) \
|
||||
and (len(get_block()) < budget or k == 0) \
|
||||
and block is get_block():
|
||||
j = i + k
|
||||
state = reducer(tuplify(loop_body(j)), state)
|
||||
@@ -974,13 +994,13 @@ def map_reduce_single(n_parallel, n_loops, initializer=lambda *x: [],
|
||||
write_state_to_memory(r)
|
||||
global n_opt_loops
|
||||
n_opt_loops = k
|
||||
regint.push(k)
|
||||
n_opt_loops_inst.args[1] = k
|
||||
return i + k
|
||||
my_n_parallel = n_opt_loops
|
||||
loop_rounds = n_loops // my_n_parallel
|
||||
blocks = get_tape().basicblocks
|
||||
n_to_merge = 5
|
||||
if loop_rounds == 1 and parent_block is blocks[-n_to_merge]:
|
||||
if util.is_one(loop_rounds) and parent_block is blocks[-n_to_merge]:
|
||||
# merge blocks started by if and do_while
|
||||
def exit_elimination(block):
|
||||
if block.exit_condition is not None:
|
||||
@@ -996,14 +1016,15 @@ def map_reduce_single(n_parallel, n_loops, initializer=lambda *x: [],
|
||||
for block in blocks[-n_to_merge + 1:]:
|
||||
merged.instructions += block.instructions
|
||||
exit_elimination(block)
|
||||
block.purge()
|
||||
block.purge(retain_usage=False)
|
||||
del blocks[-n_to_merge + 1:]
|
||||
del get_tape().req_node.children[-1]
|
||||
merged.children = []
|
||||
get_tape().active_basicblock = merged
|
||||
else:
|
||||
req_node = get_tape().req_node.children[-1].nodes[0]
|
||||
req_node.children[0].aggregator = lambda x: loop_rounds * x[0]
|
||||
if util.is_constant(loop_rounds):
|
||||
req_node.children[0].aggregator = lambda x: loop_rounds * x[0]
|
||||
if isinstance(n_loops, int):
|
||||
state = mem_state
|
||||
for j in range(loop_rounds * my_n_parallel, n_loops):
|
||||
@@ -1040,7 +1061,9 @@ def for_range_opt_multithread(n_threads, n_loops):
|
||||
"""
|
||||
Execute :py:obj:`n_loops` loop bodies in up to :py:obj:`n_threads`
|
||||
threads, in parallel up to an optimization budget per thread
|
||||
similar to :py:func:`for_range_opt`.
|
||||
similar to :py:func:`for_range_opt`. Note that optimization is rather
|
||||
rudimentary for runtime :py:obj:`n_loops` (regint/cint). Consider
|
||||
using :py:func:`for_range_multithread` in this case.
|
||||
|
||||
:param n_threads: compile-time (int)
|
||||
:param n_loops: regint/cint/int
|
||||
@@ -1089,7 +1112,7 @@ def multithread(n_threads, n_items):
|
||||
|
||||
def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \
|
||||
thread_mem_req={}, looping=True):
|
||||
n_threads = n_threads or 1
|
||||
assert(n_threads != 0)
|
||||
if isinstance(n_loops, list):
|
||||
split = n_loops
|
||||
n_loops = reduce(operator.mul, n_loops)
|
||||
@@ -1103,9 +1126,22 @@ def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \
|
||||
return new_body
|
||||
new_dec = map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, thread_mem_req)
|
||||
return lambda loop_body: new_dec(decorator(loop_body))
|
||||
n_loops = MemValue.if_necessary(n_loops)
|
||||
if n_threads == None or util.is_one(n_loops):
|
||||
if not looping:
|
||||
return lambda loop_body: loop_body(0, n_loops)
|
||||
dec = map_reduce_single(n_parallel, n_loops, initializer, reducer)
|
||||
if thread_mem_req:
|
||||
thread_mem = Array(thread_mem_req[regint], regint)
|
||||
return lambda loop_body: dec(lambda i: loop_body(i, thread_mem))
|
||||
else:
|
||||
return dec
|
||||
def decorator(loop_body):
|
||||
thread_rounds = n_loops // n_threads
|
||||
remainder = n_loops % n_threads
|
||||
thread_rounds = MemValue.if_necessary(n_loops // n_threads)
|
||||
if util.is_constant(thread_rounds):
|
||||
remainder = n_loops % n_threads
|
||||
else:
|
||||
remainder = 0
|
||||
for t in thread_mem_req:
|
||||
if t != regint:
|
||||
raise CompilerError('Not implemented for other than regint')
|
||||
@@ -1113,6 +1149,11 @@ def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \
|
||||
state = tuple(initializer())
|
||||
def f(inc):
|
||||
base = args[get_arg()][0]
|
||||
if not util.is_constant(thread_rounds):
|
||||
i = base / thread_rounds
|
||||
overhang = n_loops % n_threads
|
||||
inc = i < overhang
|
||||
base += inc.if_else(i, overhang)
|
||||
if not looping:
|
||||
return loop_body(base, thread_rounds + inc)
|
||||
if thread_mem_req:
|
||||
@@ -1129,7 +1170,7 @@ def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \
|
||||
return loop_body(base + i)
|
||||
prog = get_program()
|
||||
threads = []
|
||||
if thread_rounds:
|
||||
if not util.is_zero(thread_rounds):
|
||||
tape = prog.new_tape(f, (0,), 'multithread')
|
||||
for i in range(n_threads - remainder):
|
||||
mem_state = make_array(initializer())
|
||||
@@ -1465,6 +1506,15 @@ def get_player_id():
|
||||
playerid(res._v)
|
||||
return res
|
||||
|
||||
def break_point(name=''):
|
||||
"""
|
||||
Insert break point. This makes sure that all following code
|
||||
will be executed after preceding code.
|
||||
|
||||
:param name: Name for identification (optional)
|
||||
"""
|
||||
get_tape().start_new_basicblock(name=name)
|
||||
|
||||
# Fixed point ops
|
||||
|
||||
from math import ceil, log
|
||||
@@ -1590,6 +1640,7 @@ def IntDiv(a, b, k, kappa=None):
|
||||
return FPDiv(a.extend(2 * k) << k, b.extend(2 * k) << k, 2 * k, k,
|
||||
kappa, nearest=True)
|
||||
|
||||
@instructions_base.ret_cisc
|
||||
def FPDiv(a, b, k, f, kappa, simplex_flag=False, nearest=False):
|
||||
"""
|
||||
Goldschmidt method as presented in Catrina10,
|
||||
|
||||
278
Compiler/ml.py
278
Compiler/ml.py
@@ -1,9 +1,37 @@
|
||||
import mpc_math, math
|
||||
"""
|
||||
This module contains machine learning functionality. It is work in
|
||||
progress, so you must expect things to change. The most tested
|
||||
functionality is logistic regression. It can be run as follows::
|
||||
|
||||
sgd = ml.SGD([ml.Dense(n_examples, n_features, 1),
|
||||
ml.Output(n_examples, approx=True)], n_epochs,
|
||||
report_loss=True)
|
||||
sgd.layers[0].X.input_from(0)
|
||||
sgd.layers[1].Y.input_from(1)
|
||||
sgd.reset()
|
||||
sgd.run()
|
||||
|
||||
This loads measurements from party 0 and labels (0/1) from party
|
||||
1. After running, the model is stored in :py:obj:`sgd.layers[0].W` and
|
||||
:py:obj:`sgd.layers[1].b`. The :py:obj:`approx` parameter determines
|
||||
whether to use an approximate sigmoid function. Inference can be run as
|
||||
follows::
|
||||
|
||||
data = sfix.Matrix(n_test, n_features)
|
||||
data.input_from(0)
|
||||
res = sgd.eval(data)
|
||||
print_ln('Results: %s', [x.reveal() for x in res])
|
||||
"""
|
||||
|
||||
import math
|
||||
|
||||
from Compiler import mpc_math
|
||||
from Compiler.types import *
|
||||
from Compiler.types import _unreduced_squant
|
||||
from Compiler.library import *
|
||||
from Compiler.util import is_zero
|
||||
from Compiler.util import is_zero, tree_reduce
|
||||
from Compiler.comparison import CarryOutRawLE
|
||||
from Compiler.GC.types import sbitint
|
||||
from functools import reduce
|
||||
|
||||
def log_e(x):
|
||||
@@ -31,6 +59,29 @@ def sigmoid_prime(x):
|
||||
sx = sigmoid(x)
|
||||
return sx * (1 - sx)
|
||||
|
||||
@vectorize
|
||||
def approx_sigmoid(x):
|
||||
if approx_sigmoid.special and \
|
||||
get_program().options.ring and get_program().use_edabit():
|
||||
l = int(get_program().options.ring)
|
||||
r, r_bits = sint.get_edabit(x.k, False)
|
||||
c = ((x.v - r) << (l - x.k)).reveal() >> (l - x.k)
|
||||
c_bits = c.bit_decompose(x.k)
|
||||
lower_overflow = CarryOutRawLE(c_bits[:x.f - 1], r_bits[:x.f - 1])
|
||||
higher_bits = sbitint.bit_adder(c_bits[x.f - 1:], r_bits[x.f - 1:],
|
||||
lower_overflow)
|
||||
sign = higher_bits[-1]
|
||||
higher_bits.pop(-1)
|
||||
aa = sign & ~util.tree_reduce(operator.and_, higher_bits)
|
||||
bb = ~sign & ~util.tree_reduce(operator.and_, [~x for x in higher_bits])
|
||||
a, b = (sint.conv(x) for x in (aa, bb))
|
||||
else:
|
||||
a = x < -0.5
|
||||
b = x > 0.5
|
||||
return a.if_else(0, b.if_else(1, 0.5 + x))
|
||||
|
||||
approx_sigmoid.special = False
|
||||
|
||||
def lse_0_from_e_x(x, e_x):
|
||||
return sanitize(-x, log_e(1 + e_x), x + 2 ** -x.f, 0)
|
||||
|
||||
@@ -56,7 +107,7 @@ class Layer:
|
||||
n_threads = 1
|
||||
|
||||
class Output(Layer):
|
||||
def __init__(self, N, debug=False):
|
||||
def __init__(self, N, debug=False, approx=False):
|
||||
self.N = N
|
||||
self.X = sfix.Array(N)
|
||||
self.Y = sfix.Array(N)
|
||||
@@ -64,9 +115,8 @@ class Output(Layer):
|
||||
self.l = MemValue(sfix(-1))
|
||||
self.e_x = sfix.Array(N)
|
||||
self.debug = debug
|
||||
self.weights = cint.Array(N)
|
||||
self.weights.assign_all(1)
|
||||
self.weight_total = N
|
||||
self.weights = None
|
||||
self.approx = approx
|
||||
|
||||
nablas = lambda self: ()
|
||||
thetas = lambda self: ()
|
||||
@@ -75,30 +125,45 @@ class Output(Layer):
|
||||
def divisor(self, divisor, size):
|
||||
return cfix(1.0 / divisor, size=size)
|
||||
|
||||
def forward(self, N=None):
|
||||
N = N or self.N
|
||||
def forward(self, batch):
|
||||
if self.approx:
|
||||
self.l.write(999)
|
||||
return
|
||||
N = len(batch)
|
||||
lse = sfix.Array(N)
|
||||
@multithread(self.n_threads, N)
|
||||
def _(base, size):
|
||||
x = self.X.get_vector(base, size)
|
||||
y = self.Y.get_vector(base, size)
|
||||
y = self.Y.get(batch.get_vector(base, size))
|
||||
e_x = exp(-x)
|
||||
self.e_x.assign(e_x, base)
|
||||
lse.assign(lse_0_from_e_x(-x, e_x) + x * (1 - y), base)
|
||||
e_x = self.e_x.get_vector(0, N)
|
||||
self.l.write(sum(lse) * \
|
||||
self.divisor(self.N, 1))
|
||||
self.divisor(N, 1))
|
||||
|
||||
def backward(self):
|
||||
@multithread(self.n_threads, self.N)
|
||||
def eval(self, size, base=0):
|
||||
if self.approx:
|
||||
return approx_sigmoid(self.X.get_vector(base, size))
|
||||
else:
|
||||
return sigmoid_from_e_x(self.X.get_vector(base, size),
|
||||
self.e_x.get_vector(base, size))
|
||||
|
||||
def backward(self, batch):
|
||||
N = len(batch)
|
||||
@multithread(self.n_threads, N)
|
||||
def _(base, size):
|
||||
diff = sigmoid_from_e_x(self.X.get_vector(base, size),
|
||||
self.e_x.get_vector(base, size)) - \
|
||||
self.Y.get_vector(base, size)
|
||||
diff = self.eval(size, base) - \
|
||||
self.Y.get(batch.get_vector(base, size))
|
||||
assert sfix.f == cfix.f
|
||||
diff *= self.weights.get_vector(base, size)
|
||||
self.nabla_X.assign(diff * self.divisor(self.weight_total, size), \
|
||||
base)
|
||||
if self.weights is None:
|
||||
diff *= self.divisor(N, size)
|
||||
else:
|
||||
assert N == len(self.weights)
|
||||
diff *= self.weights.get_vector(base, size)
|
||||
if self.weight_total != 1:
|
||||
diff *= self.divisor(self.weight_total, size)
|
||||
self.nabla_X.assign(diff, base)
|
||||
# @for_range_opt(len(diff))
|
||||
# def _(i):
|
||||
# self.nabla_X[i] = self.nabla_X[i] * self.weights[i]
|
||||
@@ -112,6 +177,7 @@ class Output(Layer):
|
||||
#print_ln('%s', x)
|
||||
|
||||
def set_weights(self, weights):
|
||||
self.weights = cfix.Array(len(weights))
|
||||
self.weights.assign(weights)
|
||||
self.weight_total = sum(weights)
|
||||
|
||||
@@ -119,16 +185,28 @@ class DenseBase(Layer):
|
||||
thetas = lambda self: (self.W, self.b)
|
||||
nablas = lambda self: (self.nabla_W, self.nabla_b)
|
||||
|
||||
def backward_params(self, f_schur_Y):
|
||||
N = self.N
|
||||
def backward_params(self, f_schur_Y, batch):
|
||||
N = len(batch)
|
||||
tmp = Matrix(self.d_in, self.d_out, unreduced_sfix)
|
||||
|
||||
@for_range_opt_multithread(self.n_threads, [self.d_in, self.d_out])
|
||||
def _(j, k):
|
||||
assert self.d == 1
|
||||
a = [f_schur_Y[i][0][k] for i in range(N)]
|
||||
b = [self.X[i][0][j] for i in range(N)]
|
||||
tmp[j][k] = sfix.unreduced_dot_product(a, b)
|
||||
assert self.d == 1
|
||||
if self.d_out == 1:
|
||||
@multithread(self.n_threads, self.d_in)
|
||||
def _(base, size):
|
||||
A = sfix.Matrix(1, self.N, address=f_schur_Y.address)
|
||||
B = sfix.Matrix(self.N, self.d_in, address=self.X.address)
|
||||
mp = A.direct_mul(B, reduce=False,
|
||||
indices=(regint(0, size=1),
|
||||
regint.inc(N),
|
||||
batch.get_vector(),
|
||||
regint.inc(size, base)))
|
||||
tmp.assign_vector(mp, base)
|
||||
else:
|
||||
@for_range_opt_multithread(self.n_threads, [self.d_in, self.d_out])
|
||||
def _(j, k):
|
||||
a = [f_schur_Y[i][0][k] for i in range(N)]
|
||||
b = [self.X[i][0][j] for i in batch]
|
||||
tmp[j][k] = sfix.unreduced_dot_product(a, b)
|
||||
|
||||
if self.d_in * self.d_out < 100000:
|
||||
print('reduce at once')
|
||||
@@ -189,26 +267,34 @@ class Dense(DenseBase):
|
||||
self.W[i][j] = sfix.get_random(-r, r)
|
||||
self.b.assign_all(0)
|
||||
|
||||
def compute_f_input(self):
|
||||
prod = MultiArray([self.N, self.d, self.d_out], sfix)
|
||||
@for_range_opt_multithread(self.n_threads, self.N)
|
||||
def _(i):
|
||||
self.X[i].plain_mul(self.W, res=prod[i])
|
||||
def compute_f_input(self, batch):
|
||||
N = len(batch)
|
||||
prod = MultiArray([N, self.d, self.d_out], sfix)
|
||||
assert self.d == 1
|
||||
assert self.d_out == 1
|
||||
@multithread(self.n_threads, N)
|
||||
def _(base, size):
|
||||
X_sub = sfix.Matrix(self.N, self.d_in, address=self.X.address)
|
||||
prod.assign_vector(
|
||||
X_sub.direct_mul(self.W, indices=(batch.get_vector(base, size),
|
||||
regint.inc(self.d_in),
|
||||
regint.inc(self.d_in),
|
||||
regint.inc(self.d_out))),
|
||||
base)
|
||||
|
||||
@for_range_opt_multithread(self.n_threads, self.N)
|
||||
def _(i):
|
||||
@for_range_opt(self.d)
|
||||
def _(j):
|
||||
v = prod[i][j].get_vector() + self.b.get_vector()
|
||||
self.f_input[i][j].assign(v)
|
||||
@multithread(self.n_threads, N)
|
||||
def _(base, size):
|
||||
v = prod.get_vector(base, size) + self.b.expand_to_vector(0, size)
|
||||
self.f_input.assign_vector(v, base)
|
||||
progress('f input')
|
||||
|
||||
def forward(self):
|
||||
self.compute_f_input()
|
||||
self.Y.assign_vector(self.f(self.f_input.get_vector()))
|
||||
def forward(self, batch=None):
|
||||
self.compute_f_input(batch=batch)
|
||||
self.Y.assign_vector(self.f(
|
||||
self.f_input.get_part_vector(0, len(batch))))
|
||||
|
||||
def backward(self, compute_nabla_X=True):
|
||||
N = self.N
|
||||
def backward(self, compute_nabla_X=True, batch=None):
|
||||
N = len(batch)
|
||||
d = self.d
|
||||
d_out = self.d_out
|
||||
X = self.X
|
||||
@@ -233,6 +319,7 @@ class Dense(DenseBase):
|
||||
|
||||
@for_range_opt(N)
|
||||
def _(i):
|
||||
i = batch[i]
|
||||
f_schur_Y[i] = nabla_Y[i].schur(f_prime_bit[i])
|
||||
|
||||
progress('f prime schur Y')
|
||||
@@ -240,6 +327,7 @@ class Dense(DenseBase):
|
||||
if compute_nabla_X:
|
||||
@for_range_opt(N)
|
||||
def _(i):
|
||||
i = batch[i]
|
||||
if self.activation == 'id':
|
||||
nabla_X[i] = nabla_Y[i].mul_trans(W)
|
||||
else:
|
||||
@@ -247,7 +335,7 @@ class Dense(DenseBase):
|
||||
|
||||
progress('nabla X')
|
||||
|
||||
self.backward_params(f_schur_Y)
|
||||
self.backward_params(f_schur_Y, batch=batch)
|
||||
|
||||
class QuantizedDense(DenseBase):
|
||||
def __init__(self, N, d_in, d_out):
|
||||
@@ -443,8 +531,8 @@ class QuantConv2d(QuantConvBase):
|
||||
_, inputs_h, inputs_w, n_channels_in = self.input_shape
|
||||
return weights_h * weights_w * n_channels_in
|
||||
|
||||
def forward(self, N=1):
|
||||
assert(N == 1)
|
||||
def forward(self, batch):
|
||||
assert len(batch) == 1
|
||||
assert(self.weight_shape[0] == self.output_shape[-1])
|
||||
|
||||
_, weights_h, weights_w, _ = self.weight_shape
|
||||
@@ -499,8 +587,8 @@ class QuantDepthwiseConv2d(QuantConvBase):
|
||||
_, weights_h, weights_w, _ = self.weight_shape
|
||||
return weights_h * weights_w
|
||||
|
||||
def forward(self, N=1):
|
||||
assert(N == 1)
|
||||
def forward(self, batch):
|
||||
assert len(batch) == 1
|
||||
assert(self.weight_shape[-1] == self.output_shape[-1])
|
||||
assert(self.input_shape[-1] == self.output_shape[-1])
|
||||
|
||||
@@ -562,8 +650,8 @@ class QuantAveragePool2d(QuantBase):
|
||||
for s in self.input_squant, self.output_squant:
|
||||
s.set_params(sfloat.get_input_from(player), sint.get_input_from(player))
|
||||
|
||||
def forward(self, N=1):
|
||||
assert(N == 1)
|
||||
def forward(self, batch):
|
||||
assert len(batch) == 1
|
||||
|
||||
_, input_h, input_w, n_channels_in = self.input_shape
|
||||
_, output_h, output_w, n_channels_out = self.output_shape
|
||||
@@ -623,8 +711,8 @@ class QuantReshape(QuantBase):
|
||||
for i in range(2):
|
||||
sint.get_input_from(player)
|
||||
|
||||
def forward(self, N=1):
|
||||
assert(N == 1)
|
||||
def forward(self, batch):
|
||||
assert len(batch) == 1
|
||||
# reshaping is implicit
|
||||
self.Y.assign(self.X)
|
||||
|
||||
@@ -634,8 +722,8 @@ class QuantSoftmax(QuantBase):
|
||||
for s in self.input_squant, self.output_squant:
|
||||
s.set_params(sfloat.get_input_from(player), sint.get_input_from(player))
|
||||
|
||||
def forward(self, N=1):
|
||||
assert(N == 1)
|
||||
def forward(self, batch):
|
||||
assert len(batch) == 1
|
||||
assert(len(self.input_shape) == 2)
|
||||
|
||||
# just print the best
|
||||
@@ -648,31 +736,40 @@ class QuantSoftmax(QuantBase):
|
||||
class Optimizer:
|
||||
n_threads = Layer.n_threads
|
||||
|
||||
def forward(self, N):
|
||||
def forward(self, N=None, batch=None):
|
||||
if batch is None:
|
||||
batch = regint.Array(N)
|
||||
batch.assign(regint.inc(N))
|
||||
for j in range(len(self.layers) - 1):
|
||||
self.layers[j].forward()
|
||||
self.layers[j + 1].X.assign(self.layers[j].Y)
|
||||
self.layers[-1].forward(N)
|
||||
self.layers[j].forward(batch=batch)
|
||||
tmp = self.layers[j].Y.get_part_vector(0, len(batch))
|
||||
self.layers[j + 1].X.assign_vector(tmp)
|
||||
self.layers[-1].forward(batch=batch)
|
||||
|
||||
def backward(self):
|
||||
def eval(self, data):
|
||||
N = len(data)
|
||||
self.layers[0].X.assign(data)
|
||||
self.forward(N)
|
||||
return self.layers[-1].eval(N)
|
||||
|
||||
def backward(self, batch):
|
||||
for j in range(1, len(self.layers)):
|
||||
self.layers[-j].backward()
|
||||
self.layers[-j - 1].nabla_Y.assign(self.layers[-j].nabla_X)
|
||||
self.layers[0].backward(compute_nabla_X=False)
|
||||
self.layers[-j].backward(batch=batch)
|
||||
self.layers[-j - 1].nabla_Y.assign_vector(
|
||||
self.layers[-j].nabla_X.get_part_vector(0, len(batch)))
|
||||
self.layers[0].backward(compute_nabla_X=False, batch=batch)
|
||||
|
||||
def run(self):
|
||||
def run(self, batch_size=None):
|
||||
if batch_size is not None:
|
||||
N = batch_size
|
||||
else:
|
||||
N = self.layers[0].N
|
||||
i = MemValue(0)
|
||||
@do_while
|
||||
def _():
|
||||
if self.X_by_label is not None:
|
||||
N = self.layers[0].N
|
||||
assert self.layers[-1].N == N
|
||||
assert N % 2 == 0
|
||||
n = N // 2
|
||||
@for_range(n)
|
||||
def _(i):
|
||||
self.layers[-1].Y[i] = 0
|
||||
self.layers[-1].Y[i + n] = 1
|
||||
n_per_epoch = int(math.ceil(1. * max(len(X) for X in
|
||||
self.X_by_label) / n))
|
||||
print('%d runs per epoch' % n_per_epoch)
|
||||
@@ -680,26 +777,27 @@ class Optimizer:
|
||||
for label, X in enumerate(self.X_by_label):
|
||||
indices = regint.Array(n * n_per_epoch)
|
||||
indices_by_label.append(indices)
|
||||
indices.assign(i % len(X) for i in range(len(indices)))
|
||||
indices.assign(regint.inc(len(indices), 0, 1, 1, len(X)))
|
||||
indices.shuffle()
|
||||
@for_range(n_per_epoch)
|
||||
def _(j):
|
||||
j = MemValue(j)
|
||||
batch = regint.Array(N)
|
||||
for label, X in enumerate(self.X_by_label):
|
||||
indices = indices_by_label[label]
|
||||
@for_range_multithread(self.n_threads, 1, n)
|
||||
def _(i):
|
||||
idx = indices[i + j * n]
|
||||
self.layers[0].X[i + label * n] = X[idx]
|
||||
self.forward(None)
|
||||
self.backward()
|
||||
batch.assign(indices.get_vector(j * n, n) +
|
||||
regint(label * len(self.X_by_label[0]), size=n),
|
||||
label * n)
|
||||
self.forward(batch=batch)
|
||||
self.backward(batch=batch)
|
||||
self.update(i)
|
||||
else:
|
||||
self.forward(None)
|
||||
self.backward()
|
||||
batch = regint.Array(N)
|
||||
batch.assign(regint.inc(N))
|
||||
self.forward(batch=batch)
|
||||
self.backward(batch=batch)
|
||||
self.update(i)
|
||||
loss = self.layers[-1].l
|
||||
if self.report_loss:
|
||||
if self.report_loss and not self.layers[-1].approx:
|
||||
print_ln('loss after epoch %s: %s', i, loss.reveal())
|
||||
else:
|
||||
print_ln('done with epoch %s', i)
|
||||
@@ -772,6 +870,13 @@ class SGD(Optimizer):
|
||||
|
||||
def reset(self, X_by_label=None):
|
||||
self.X_by_label = X_by_label
|
||||
if X_by_label is not None:
|
||||
for label, X in enumerate(X_by_label):
|
||||
@for_range_multithread(self.n_threads, 1, len(X))
|
||||
def _(i):
|
||||
j = i + label * len(X_by_label[0])
|
||||
self.layers[0].X[j] = X[i]
|
||||
self.layers[-1].Y[j] = label
|
||||
for y in self.delta_thetas:
|
||||
y.assign_all(0)
|
||||
for layer in self.layers:
|
||||
@@ -780,16 +885,15 @@ class SGD(Optimizer):
|
||||
def update(self, i_epoch):
|
||||
for nabla, theta, delta_theta in zip(self.nablas, self.thetas,
|
||||
self.delta_thetas):
|
||||
@for_range_opt_multithread(self.n_threads, len(nabla))
|
||||
def _(k):
|
||||
old = delta_theta[k]
|
||||
if isinstance(old, Array):
|
||||
old = old.get_vector()
|
||||
@multithread(self.n_threads, len(nabla))
|
||||
def _(base, size):
|
||||
old = delta_theta.get_vector(base, size)
|
||||
red_old = self.momentum * old
|
||||
new = self.gamma * nabla[k]
|
||||
new = self.gamma * nabla.get_vector(base, size)
|
||||
diff = red_old - new
|
||||
delta_theta[k] = diff
|
||||
theta[k] = theta[k] + delta_theta[k]
|
||||
delta_theta.assign_vector(diff, base)
|
||||
theta.assign_vector(theta.get_vector(base, size) +
|
||||
delta_theta.get_vector(base, size), base)
|
||||
if self.debug:
|
||||
for x, name in (old, 'old'), (red_old, 'red_old'), \
|
||||
(new, 'new'), (diff, 'diff'):
|
||||
|
||||
@@ -12,6 +12,8 @@ from Compiler import floatingpoint
|
||||
from Compiler import types
|
||||
from Compiler import comparison
|
||||
from Compiler import program
|
||||
from Compiler import instructions_base
|
||||
|
||||
# polynomials as enumerated on Hart's book
|
||||
##
|
||||
# @private
|
||||
@@ -154,8 +156,8 @@ def p_eval(p_c, x):
|
||||
def sTrigSub(x):
|
||||
# reduction to 2* \pi
|
||||
f = x * (1.0 / (2 * pi))
|
||||
f = load_sint(trunc(f), type(x))
|
||||
y = x - (f) * (2 * pi)
|
||||
f = trunc(f)
|
||||
y = x - (f) * x.coerce(2 * pi)
|
||||
# reduction to \pi
|
||||
b1 = y > pi
|
||||
w = b1 * ((2 * pi - y) - y) + y
|
||||
@@ -210,6 +212,7 @@ def scos(w, s):
|
||||
|
||||
# facade method calls --it is built in a generic way
|
||||
|
||||
@instructions_base.sfix_cisc
|
||||
def sin(x):
|
||||
"""
|
||||
Returns the sine of any given fractional value.
|
||||
@@ -224,6 +227,7 @@ def sin(x):
|
||||
return ssin(w, b1)
|
||||
|
||||
|
||||
@instructions_base.sfix_cisc
|
||||
def cos(x):
|
||||
"""
|
||||
Returns the cosine of any given fractional value.
|
||||
@@ -239,6 +243,7 @@ def cos(x):
|
||||
return scos(w, b2)
|
||||
|
||||
|
||||
@instructions_base.sfix_cisc
|
||||
def tan(x):
|
||||
"""
|
||||
Returns the tangent of any given fractional value.
|
||||
@@ -258,6 +263,7 @@ def tan(x):
|
||||
|
||||
|
||||
@types.vectorize
|
||||
@instructions_base.sfix_cisc
|
||||
def exp2_fx(a):
|
||||
"""
|
||||
Power of two for fixed-point numbers.
|
||||
@@ -273,39 +279,49 @@ def exp2_fx(a):
|
||||
n_int_bits = int(math.ceil(math.log(a.k - a.f, 2)))
|
||||
n_bits = a.f + n_int_bits
|
||||
n_shift = int(types.program.options.ring) - a.k
|
||||
r_bits = [sint.get_random_bit() for i in range(a.k)]
|
||||
shifted = ((a.v - sint.bit_compose(r_bits)) << n_shift).reveal()
|
||||
if types.program.use_edabit():
|
||||
l = sint.get_edabit(a.f, True)
|
||||
u = sint.get_edabit(a.k - a.f, True)
|
||||
r_bits = l[1] + u[1]
|
||||
r = l[0] + (u[0] << a.f)
|
||||
lower_r = l[0]
|
||||
else:
|
||||
r_bits = [sint.get_random_bit() for i in range(a.k)]
|
||||
r = sint.bit_compose(r_bits)
|
||||
lower_r = sint.bit_compose(r_bits[:a.f])
|
||||
shifted = ((a.v - r) << n_shift).reveal()
|
||||
masked_bits = (shifted >> n_shift).bit_decompose(a.k)
|
||||
lower_overflow = sint()
|
||||
comparison.CarryOut(lower_overflow, masked_bits[a.f-1::-1],
|
||||
lower_overflow = comparison.CarryOutRaw(masked_bits[a.f-1::-1],
|
||||
r_bits[a.f-1::-1])
|
||||
lower_r = sint.bit_compose(r_bits[:a.f])
|
||||
lower_masked = sint.bit_compose(masked_bits[:a.f])
|
||||
lower = lower_r + lower_masked - (lower_overflow << (a.f))
|
||||
lower = lower_r + lower_masked - (sint.conv(lower_overflow) << (a.f))
|
||||
c = types.sfix._new(lower, k=a.k, f=a.f)
|
||||
higher_bits = intbitint.bit_adder(masked_bits[a.f:n_bits],
|
||||
r_bits[a.f:n_bits],
|
||||
higher_bits = r_bits[0].bit_adder(r_bits[a.f:n_bits],
|
||||
masked_bits[a.f:n_bits],
|
||||
carry_in=lower_overflow,
|
||||
get_carry=True)
|
||||
d = types.sfix.from_sint(floatingpoint.Pow2_from_bits(higher_bits[:-1]),
|
||||
k=a.k, f=a.f)
|
||||
assert(len(higher_bits) == n_bits - a.f + 1)
|
||||
pow2_bits = [sint.conv(x) for x in higher_bits]
|
||||
d = floatingpoint.Pow2_from_bits(pow2_bits[:-1])
|
||||
e = p_eval(p_1045, c)
|
||||
g = d * e
|
||||
small_result = types.sfix._new(g.v.round(a.k + 1, a.f, signed=False,
|
||||
small_result = types.sfix._new(g.v.round(a.f + 2 ** n_int_bits,
|
||||
2 ** n_int_bits, signed=False,
|
||||
nearest=types.sfix.round_nearest),
|
||||
k=a.k, f=a.f)
|
||||
carry = comparison.CarryOutLE(masked_bits[n_bits:-1],
|
||||
r_bits[n_bits:-1],
|
||||
higher_bits[-1])
|
||||
# should be for free
|
||||
highest_bits = intbitint.ripple_carry_adder(
|
||||
highest_bits = r_bits[0].ripple_carry_adder(
|
||||
masked_bits[n_bits:-1], [0] * (a.k - n_bits),
|
||||
carry_in=higher_bits[-1])
|
||||
bits_to_check = [x.bit_xor(y)
|
||||
for x, y in zip(highest_bits[:-1], r_bits[n_bits:-1])]
|
||||
t = floatingpoint.KMul(bits_to_check)
|
||||
t = sint.conv(floatingpoint.KOpL(lambda x, y: x.bit_and(y),
|
||||
bits_to_check))
|
||||
# sign
|
||||
s = masked_bits[-1].bit_xor(r_bits[-1]).bit_xor(carry)
|
||||
s = carry.bit_xor(sint.conv(r_bits[-1])).bit_xor(masked_bits[-1])
|
||||
return s.if_else(t.if_else(small_result, 0), g)
|
||||
else:
|
||||
# obtain absolute value of a
|
||||
@@ -313,16 +329,17 @@ def exp2_fx(a):
|
||||
a = (s * (-2) + 1) * a
|
||||
# isolates fractional part of number
|
||||
b = trunc(a)
|
||||
c = a - load_sint(b, type(a))
|
||||
c = a - b
|
||||
# squares integer part of a
|
||||
d = load_sint(b.pow2(types.sfix.k - types.sfix.f), type(a))
|
||||
d = b.pow2(a.k - a.f)
|
||||
# evaluates fractional part of a in p_1045
|
||||
e = p_eval(p_1045, c)
|
||||
g = d * e
|
||||
return (1 - s) * g + s * ((types.sfix(1)) / g)
|
||||
return (1 - s) * g + s / g
|
||||
|
||||
|
||||
@types.vectorize
|
||||
@instructions_base.sfix_cisc
|
||||
def log2_fx(x):
|
||||
"""
|
||||
Returns the result of :math:`\log_2(x)` for any unbounded
|
||||
@@ -347,14 +364,13 @@ def log2_fx(x):
|
||||
v, p, vlen = d.v, d.p, d.vlen
|
||||
# isolates mantisa of d, now the n can be also substituted by the
|
||||
# secret shared p from d in the expresion above.
|
||||
v = load_sint(v, type(x))
|
||||
w = (1.0 / (2 ** (vlen)))
|
||||
w = x.coerce(1.0 / (2 ** (vlen)))
|
||||
v = v * w
|
||||
# polynomials for the log_2 evaluation of f are calculated
|
||||
P = p_eval(p_2524, v)
|
||||
Q = p_eval(q_2524, v)
|
||||
# the log is returned by adding the result of the division plus p.
|
||||
a = P / Q + load_sint(vlen + p, type(x))
|
||||
a = P / Q + (vlen + p)
|
||||
return a # *(1-(f.z))*(1-f.s)*(1-f.error)
|
||||
|
||||
|
||||
@@ -515,7 +531,7 @@ def norm_simplified_SQ(b, k):
|
||||
# @return g: approximated sqrt
|
||||
def sqrt_simplified_fx(x):
|
||||
# fix theta (number of iterations)
|
||||
theta = max(int(math.ceil(math.log(types.sfix.k))), 6)
|
||||
theta = max(int(math.ceil(math.log(x.k))), 6)
|
||||
|
||||
# process to use 2^(m/2) approximation
|
||||
m_odd, m, w = norm_simplified_SQ(x.v, x.k)
|
||||
@@ -524,15 +540,15 @@ def sqrt_simplified_fx(x):
|
||||
m_odd = (1 - 2 * m_odd) + m_odd
|
||||
w = (w * 2 - w) * (1-m_odd) + w
|
||||
# map number to use sfix format and instantiate the number
|
||||
w = types.sfix(w * 2 ** ((x.f - (x.f % 2)) // 2))
|
||||
w = types.sfix(w * 2 ** ((x.f - (x.f % 2)) // 2), k=x.k, f=x.f)
|
||||
# obtains correct 2 ** (m/2)
|
||||
w = (w * (types.cfix(2 ** (1/2.0))) - w) * m_odd + w
|
||||
w = (w * (2 ** (1/2.0)) - w) * m_odd + w
|
||||
# produce x/ 2^(m/2)
|
||||
y_0 = types.cfix(1.0) / w
|
||||
y_0 = 1 / w
|
||||
|
||||
# from this point on it sufices to work sfix-wise
|
||||
g_0 = (y_0 * x)
|
||||
h_0 = y_0 * types.cfix(0.5)
|
||||
h_0 = y_0 * 0.5
|
||||
gh_0 = g_0 * h_0
|
||||
|
||||
## initialization
|
||||
@@ -689,7 +705,8 @@ def sqrt_fx(x_l, k, f):
|
||||
|
||||
|
||||
@types.vectorize
|
||||
def sqrt(x, k = types.sfix.k, f = types.sfix.f):
|
||||
@instructions_base.sfix_cisc
|
||||
def sqrt(x, k=None, f=None):
|
||||
"""
|
||||
Returns the square root (sfix) of any given fractional
|
||||
value as long as it can be rounded to a integral value
|
||||
@@ -699,7 +716,11 @@ def sqrt(x, k = types.sfix.k, f = types.sfix.f):
|
||||
|
||||
:return: square root of :py:obj:`x` (sfix).
|
||||
"""
|
||||
if (3 *k -2 * f >= types.sfix.f):
|
||||
if k is None:
|
||||
k = x.k
|
||||
if f is None:
|
||||
f = x.f
|
||||
if (3 *k -2 * f >= f):
|
||||
return sqrt_simplified_fx(x)
|
||||
# raise OverflowError("bound for precision violated: 3 * k - 2 * f < x.f ")
|
||||
else:
|
||||
@@ -707,6 +728,7 @@ def sqrt(x, k = types.sfix.k, f = types.sfix.f):
|
||||
return sqrt_fx(param ,k ,f)
|
||||
|
||||
|
||||
@instructions_base.sfix_cisc
|
||||
def atan(x):
|
||||
"""
|
||||
Returns the arctangent (sfix) of any given fractional value.
|
||||
@@ -720,12 +742,12 @@ def atan(x):
|
||||
x_abs = (s * (-2) + 1) * x
|
||||
# angle isolation
|
||||
b = x_abs > 1
|
||||
v = (types.cfix(1.0) / x_abs)
|
||||
v = 1 / x_abs
|
||||
v = (1 - b) * (x_abs - v) + v
|
||||
v_2 =v*v
|
||||
|
||||
# range of polynomial coefficients
|
||||
assert x.k - x.f >= 18
|
||||
assert x.k - x.f >= 15
|
||||
P = p_eval(p_5102, v_2)
|
||||
Q = p_eval(q_5102, v_2)
|
||||
|
||||
|
||||
@@ -57,10 +57,14 @@ class intBlock(Block):
|
||||
length = sum(self.lengths)
|
||||
self.n_bits = length * entries_per_block
|
||||
self.start = self.value_type.hard_conv(start * length)
|
||||
self.lower, self.shift = \
|
||||
floatingpoint.Trunc(self.value, self.n_bits, self.start, \
|
||||
if Program.prog.options.ring:
|
||||
self.lower, trunc, self.shift = floatingpoint.SplitInRing(
|
||||
self.value, self.n_bits, self.start)
|
||||
else:
|
||||
self.lower, self.shift = \
|
||||
floatingpoint.Trunc(self.value, self.n_bits, self.start, \
|
||||
Program.prog.security, True)
|
||||
trunc = (self.value - self.lower) / self.shift
|
||||
trunc = (self.value - self.lower) / self.shift
|
||||
self.slice = trunc.mod2m(length, self.n_bits, False)
|
||||
self.upper = (trunc - self.slice) * self.shift
|
||||
def get_slice(self):
|
||||
@@ -810,7 +814,7 @@ def get_n_threads(n_loops):
|
||||
if n_loops > 2048:
|
||||
return 8
|
||||
else:
|
||||
return 1
|
||||
return None
|
||||
else:
|
||||
return n_threads
|
||||
|
||||
@@ -1375,7 +1379,11 @@ def get_value_size(value_type):
|
||||
if value_type == sgf2n:
|
||||
return Program.prog.galois_length
|
||||
elif value_type == sint:
|
||||
return 127 - Program.prog.security
|
||||
ring = Program.prog.options.ring
|
||||
if ring:
|
||||
return int(ring)
|
||||
else:
|
||||
return 127 - Program.prog.security
|
||||
else:
|
||||
return value_type.max_length
|
||||
|
||||
@@ -1477,11 +1485,13 @@ class PackedIndexStructure(object):
|
||||
rem = mod2m(index, self.log_entries_per_block, log2(self.size), False)
|
||||
c = mod2m(rem, self.log_entries_per_element, \
|
||||
self.log_entries_per_block, False)
|
||||
b = (rem - c) / self.entries_per_element
|
||||
b = (rem - c).trunc_zeros(self.log_entries_per_element,
|
||||
self.log_entries_per_block)
|
||||
if self.small:
|
||||
return 0, b, c
|
||||
else:
|
||||
return (index - rem) / self.entries_per_block, b, c
|
||||
return (index - rem).trunc_zeros(self.log_entries_per_block,
|
||||
log2(self.size)), b, c
|
||||
else:
|
||||
index_bits = bit_decompose(index, log2(self.size))
|
||||
l1 = self.log_entries_per_element
|
||||
|
||||
@@ -301,7 +301,8 @@ def iter_waksman(a, config, reverse=False):
|
||||
conf_address = MemValue(config.address + depth.read()*n)
|
||||
do_round(size, conf_address, a.address, a2.address, 1)
|
||||
|
||||
for i in range(n):
|
||||
@for_range(n)
|
||||
def _(i):
|
||||
a[i] = a2[i]
|
||||
|
||||
nblocks.write(nblocks*2)
|
||||
@@ -317,7 +318,8 @@ def iter_waksman(a, config, reverse=False):
|
||||
conf_address = MemValue(config.address + depth.read()*n)
|
||||
do_round(size, conf_address, a.address, a2.address, 0)
|
||||
|
||||
for i in range(n):
|
||||
@for_range(n)
|
||||
def _(i):
|
||||
a[i] = a2[i]
|
||||
|
||||
nblocks.write(nblocks//2)
|
||||
@@ -379,6 +381,14 @@ def config_shuffle(n, value_type):
|
||||
config_bits = configure_waksman(perm)
|
||||
# 2-D array
|
||||
config = Array(len(config_bits) * len(perm), value_type.reg_type)
|
||||
if n > 1024:
|
||||
for x in config_bits:
|
||||
for y in x:
|
||||
get_program().public_input(y)
|
||||
@for_range(sum(len(x) for x in config_bits))
|
||||
def _(i):
|
||||
config[i] = public_input()
|
||||
return config
|
||||
for i,c in enumerate(config_bits):
|
||||
for j,b in enumerate(c):
|
||||
config[i * len(perm) + j] = b
|
||||
|
||||
@@ -22,12 +22,14 @@ data_types = dict(
|
||||
bit = 2,
|
||||
inverse = 3,
|
||||
bittriple = 4,
|
||||
bitgf2ntriple = 5
|
||||
bitgf2ntriple = 5,
|
||||
dabit = 6,
|
||||
)
|
||||
|
||||
field_types = dict(
|
||||
modp = 0,
|
||||
gf2n = 1,
|
||||
bit = 2,
|
||||
)
|
||||
|
||||
|
||||
@@ -39,6 +41,7 @@ class Program(object):
|
||||
and threads. """
|
||||
def __init__(self, args, options, param=-1, assemblymode=False):
|
||||
self.options = options
|
||||
self.verbose = options.verbose
|
||||
self.args = args
|
||||
self.init_names(args, assemblymode)
|
||||
self.P = P_VALUES[param]
|
||||
@@ -56,7 +59,8 @@ class Program(object):
|
||||
self.security = 40
|
||||
print('Default security parameter:', self.security)
|
||||
self.galois_length = int(options.galois)
|
||||
print('Galois length:', self.galois_length)
|
||||
if self.verbose:
|
||||
print('Galois length:', self.galois_length)
|
||||
self.schedule = [('start', [])]
|
||||
self.tape_counter = 0
|
||||
self.tapes = []
|
||||
@@ -73,9 +77,9 @@ class Program(object):
|
||||
self.n_threads = 1
|
||||
self.free_threads = set()
|
||||
self.public_input_file = open(self.programs_dir + '/Public-Input/%s' % self.name, 'w')
|
||||
self.use_public_input_file = False
|
||||
self.types = {}
|
||||
self.budget = int(self.options.budget)
|
||||
self.verbose = False
|
||||
self.to_merge = [Compiler.instructions.asm_open_class, \
|
||||
Compiler.instructions.gasm_open_class, \
|
||||
Compiler.instructions.muls_class, \
|
||||
@@ -89,12 +93,15 @@ class Program(object):
|
||||
Compiler.instructions.inputfix_class,
|
||||
Compiler.instructions.inputfloat_class,
|
||||
Compiler.instructions.inputmixed_class,
|
||||
Compiler.instructions.trunc_pr_class]
|
||||
Compiler.instructions.trunc_pr_class,
|
||||
Compiler.instructions_base.Mergeable]
|
||||
import Compiler.GC.instructions as gc
|
||||
self.to_merge += [gc.ldmsdi, gc.stmsdi, gc.ldmsd, gc.stmsd, \
|
||||
gc.stmsdci, gc.xors, gc.andrs, gc.ands, gc.inputb]
|
||||
self.use_trunc_pr = False
|
||||
self.use_dabit = options.mixed
|
||||
self._edabit = options.edabit
|
||||
self._split = False
|
||||
Program.prog = self
|
||||
|
||||
self.reset_values()
|
||||
@@ -132,7 +139,8 @@ class Program(object):
|
||||
else:
|
||||
# assume source is in main SPDZ directory
|
||||
self.programs_dir = sys.path[0] + '/Programs'
|
||||
print('Compiling program in', self.programs_dir)
|
||||
if self.verbose:
|
||||
print('Compiling program in', self.programs_dir)
|
||||
|
||||
# create extra directories if needed
|
||||
for dirname in ['Public-Input', 'Bytecode', 'Schedules']:
|
||||
@@ -161,7 +169,7 @@ class Program(object):
|
||||
self.name += '-' + '-'.join(args[1:])
|
||||
self.progname = progname
|
||||
|
||||
def new_tape(self, function, args=[], name=None):
|
||||
def new_tape(self, function, args=[], name=None, single_thread=False):
|
||||
if name is None:
|
||||
name = function.__name__
|
||||
name = "%s-%s" % (self.name, name)
|
||||
@@ -170,7 +178,7 @@ class Program(object):
|
||||
tape_index = len(self.tapes)
|
||||
self.tape_stack.append(self.curr_tape)
|
||||
self.curr_tape = Tape(name, self)
|
||||
self.curr_tape.prevent_direct_memory_write = True
|
||||
self.curr_tape.prevent_direct_memory_write = not single_thread
|
||||
self.tapes.append(self.curr_tape)
|
||||
function(*args)
|
||||
self.finalize_tape(self.curr_tape)
|
||||
@@ -183,7 +191,8 @@ class Program(object):
|
||||
raise CompilerError('Compiler does not support ' \
|
||||
'recursive spawning of threads')
|
||||
if self.free_threads:
|
||||
thread_number = self.free_threads.pop()
|
||||
thread_number = min(self.free_threads)
|
||||
self.free_threads.remove(thread_number)
|
||||
else:
|
||||
thread_number = self.n_threads
|
||||
self.n_threads += 1
|
||||
@@ -376,7 +385,7 @@ class Program(object):
|
||||
else:
|
||||
addr = self.allocated_mem[mem_type]
|
||||
self.allocated_mem[mem_type] += size
|
||||
if len(str(addr)) != len(str(addr + size)):
|
||||
if len(str(addr)) != len(str(addr + size)) and self.verbose:
|
||||
print("Memory of type '%s' now of size %d" % (mem_type, addr + size))
|
||||
self.allocated_mem_blocks[addr,mem_type] = size
|
||||
return addr
|
||||
@@ -404,6 +413,7 @@ class Program(object):
|
||||
|
||||
def public_input(self, x):
|
||||
self.public_input_file.write('%s\n' % str(x))
|
||||
self.use_public_input_file = True
|
||||
|
||||
def set_bit_length(self, bit_length):
|
||||
self.bit_length = bit_length
|
||||
@@ -421,6 +431,22 @@ class Program(object):
|
||||
self.tape_counter += 1
|
||||
return res
|
||||
|
||||
def use_edabit(self, change=None):
|
||||
if change is None:
|
||||
return self._edabit
|
||||
else:
|
||||
self._edabit = change
|
||||
|
||||
def use_edabit_for(self, *args):
|
||||
return True
|
||||
|
||||
def use_split(self, change=None):
|
||||
if change is None:
|
||||
return self._split
|
||||
else:
|
||||
assert change in (2, 3)
|
||||
self._split = change
|
||||
|
||||
class Tape:
|
||||
""" A tape contains a list of basic blocks, onto which instructions are added. """
|
||||
def __init__(self, name, program):
|
||||
@@ -503,13 +529,17 @@ class Tape:
|
||||
self.exit_condition.set_relative_jump(offset)
|
||||
#print 'Basic block %d jumps to %d (%d)' % (next_block_index, jump_index, offset)
|
||||
|
||||
def purge(self):
|
||||
def purge(self, retain_usage=True):
|
||||
def relevant(inst):
|
||||
req_node = Tape.ReqNode('')
|
||||
req_node.num = Tape.ReqNum()
|
||||
inst.add_usage(req_node)
|
||||
return req_node.num != {}
|
||||
self.usage_instructions = list(filter(relevant, self.instructions))
|
||||
if retain_usage:
|
||||
self.usage_instructions = list(filter(relevant,
|
||||
self.instructions))
|
||||
else:
|
||||
self.usage_instructions = []
|
||||
if len(self.usage_instructions) > 1000:
|
||||
print('Retaining %d instructions' % len(self.usage_instructions))
|
||||
del self.instructions
|
||||
@@ -526,6 +556,12 @@ class Tape:
|
||||
req_node.num['all', 'round'] = self.n_rounds
|
||||
req_node.num['all', 'inv'] = self.n_to_merge
|
||||
|
||||
def expand_cisc(self):
|
||||
new_instructions = []
|
||||
for inst in self.instructions:
|
||||
new_instructions.extend(inst.expand_merged())
|
||||
self.instructions = new_instructions
|
||||
|
||||
def __str__(self):
|
||||
return self.name
|
||||
|
||||
@@ -577,8 +613,9 @@ class Tape:
|
||||
def unpurged(function):
|
||||
def wrapper(self, *args, **kwargs):
|
||||
if self.purged:
|
||||
print('%s called on purged block %s, ignoring' % \
|
||||
(function.__name__, self.name))
|
||||
if self.program.verbose:
|
||||
print('%s called on purged block %s, ignoring' % \
|
||||
(function.__name__, self.name))
|
||||
return
|
||||
return function(self, *args, **kwargs)
|
||||
return wrapper
|
||||
@@ -592,7 +629,8 @@ class Tape:
|
||||
if self.if_states:
|
||||
raise CompilerError('Unclosed if/else blocks')
|
||||
|
||||
print('Processing tape', self.name, 'with %d blocks' % len(self.basicblocks))
|
||||
if self.program.verbose:
|
||||
print('Processing tape', self.name, 'with %d blocks' % len(self.basicblocks))
|
||||
|
||||
for block in self.basicblocks:
|
||||
al.determine_scope(block, options)
|
||||
@@ -601,7 +639,7 @@ class Tape:
|
||||
# need to do this if there are several blocks
|
||||
if (options.merge_opens and self.merge_opens) or options.dead_code_elimination:
|
||||
for i,block in enumerate(self.basicblocks):
|
||||
if len(block.instructions) > 0:
|
||||
if len(block.instructions) > 0 and self.program.verbose:
|
||||
print('Processing basic block %s, %d/%d, %d instructions' % \
|
||||
(block.name, i, len(self.basicblocks), \
|
||||
len(block.instructions)))
|
||||
@@ -609,7 +647,7 @@ class Tape:
|
||||
merger = al.Merger(block, options, \
|
||||
tuple(self.program.to_merge))
|
||||
if options.dead_code_elimination:
|
||||
if len(block.instructions) > 10000:
|
||||
if len(block.instructions) > 100000:
|
||||
print('Eliminate dead code...')
|
||||
merger.eliminate_dead_code()
|
||||
if options.merge_opens and self.merge_opens:
|
||||
@@ -617,14 +655,14 @@ class Tape:
|
||||
block.used_from_scope = util.set_by_id()
|
||||
block.defined_registers = util.set_by_id()
|
||||
continue
|
||||
if len(block.instructions) > 10000:
|
||||
if len(block.instructions) > 100000:
|
||||
print('Merging instructions...')
|
||||
numrounds = merger.longest_paths_merge()
|
||||
block.n_rounds = numrounds
|
||||
block.n_to_merge = len(merger.open_nodes)
|
||||
if numrounds > 0:
|
||||
if numrounds > 0 and self.program.verbose:
|
||||
print('Program requires %d rounds of communication' % numrounds)
|
||||
if merger.counter:
|
||||
if merger.counter and self.program.verbose:
|
||||
print('Block requires', \
|
||||
', '.join('%d %s' % (y, x.__name__) \
|
||||
for x, y in list(merger.counter.items())))
|
||||
@@ -635,6 +673,9 @@ class Tape:
|
||||
if not (options.merge_opens and self.merge_opens):
|
||||
print('Not merging instructions in tape %s' % self.name)
|
||||
|
||||
if options.cisc:
|
||||
self.expand_cisc()
|
||||
|
||||
# add jumps
|
||||
offset = 0
|
||||
for block in self.basicblocks:
|
||||
@@ -659,7 +700,7 @@ class Tape:
|
||||
print('modp: %d clear, %d secret' % (reg_counts[RegType.ClearModp], reg_counts[RegType.SecretModp]))
|
||||
print('GF2N: %d clear, %d secret' % (reg_counts[RegType.ClearGF2N], reg_counts[RegType.SecretGF2N]))
|
||||
print('Re-allocating...')
|
||||
allocator = al.StraightlineAllocator(REG_MAX)
|
||||
allocator = al.StraightlineAllocator(REG_MAX, self.program)
|
||||
def alloc(block):
|
||||
for reg in sorted(block.used_from_scope,
|
||||
key=lambda x: (x.reg_type, x.i)):
|
||||
@@ -673,7 +714,7 @@ class Tape:
|
||||
if child.instructions:
|
||||
left.append(child)
|
||||
for i,block in enumerate(reversed(self.basicblocks)):
|
||||
if len(block.instructions) > 10000:
|
||||
if len(block.instructions) > 100000:
|
||||
print('Allocating %s, %d/%d' % \
|
||||
(block.name, i, len(self.basicblocks)))
|
||||
if block.exit_condition is not None:
|
||||
@@ -684,9 +725,11 @@ class Tape:
|
||||
allocator.process(block.instructions, block.alloc_pool)
|
||||
|
||||
# offline data requirements
|
||||
print('Compile offline data requirements...')
|
||||
if self.program.verbose:
|
||||
print('Compile offline data requirements...')
|
||||
self.req_num = self.req_tree.aggregate()
|
||||
print('Tape requires', self.req_num)
|
||||
if self.program.verbose:
|
||||
print('Tape requires', self.req_num)
|
||||
for req,num in sorted(self.req_num.items()):
|
||||
if num == float('inf') or num >= 2 ** 32:
|
||||
num = -1
|
||||
@@ -708,6 +751,14 @@ class Tape:
|
||||
self.basicblocks[-1].instructions.append(
|
||||
Compiler.instructions.guse_prep(req[1], num, \
|
||||
add_to_prog=False))
|
||||
elif req[0] == 'edabit':
|
||||
self.basicblocks[-1].instructions.append(
|
||||
Compiler.instructions.use_edabit(False, req[1], num, \
|
||||
add_to_prog=False))
|
||||
elif req[0] == 'sedabit':
|
||||
self.basicblocks[-1].instructions.append(
|
||||
Compiler.instructions.use_edabit(True, req[1], num, \
|
||||
add_to_prog=False))
|
||||
|
||||
if not self.is_empty():
|
||||
# bit length requirement
|
||||
@@ -723,6 +774,11 @@ class Tape:
|
||||
print('Tape requires prime bit length', self.req_bit_length['p'])
|
||||
print('Tape requires galois bit length', self.req_bit_length['2'])
|
||||
|
||||
@unpurged
|
||||
def expand_cisc(self):
|
||||
for block in self.basicblocks:
|
||||
block.expand_cisc()
|
||||
|
||||
@unpurged
|
||||
def _get_instructions(self):
|
||||
return itertools.chain.\
|
||||
@@ -786,7 +842,7 @@ class Tape:
|
||||
|
||||
def reset_registers(self):
|
||||
""" Reset register values to zero. """
|
||||
self.reg_values = RegType.create_dict(lambda: [0] * INIT_REG_MAX)
|
||||
self.reg_values = RegType.create_dict(lambda: [])
|
||||
|
||||
def get_value(self, reg_type, i):
|
||||
return self.reg_values[reg_type][i]
|
||||
@@ -826,7 +882,7 @@ class Tape:
|
||||
return res
|
||||
def cost(self):
|
||||
return sum(num * COST[req[0]][req[1]] for req,num in list(self.items()) \
|
||||
if req[1] != 'input')
|
||||
if req[1] != 'input' and req[0] != 'edabit')
|
||||
def __str__(self):
|
||||
return ", ".join('%s inputs in %s from player %d' \
|
||||
% (num, req[0], req[2]) \
|
||||
@@ -927,9 +983,11 @@ class Tape:
|
||||
self.relative_i = 0
|
||||
if i is not None:
|
||||
self.i = i
|
||||
else:
|
||||
elif size is not None:
|
||||
self.i = program.reg_counter[reg_type]
|
||||
program.reg_counter[reg_type] += size
|
||||
else:
|
||||
self.i = float('inf')
|
||||
self.vector = []
|
||||
if value is not None:
|
||||
self.value = value
|
||||
@@ -952,8 +1010,9 @@ class Tape:
|
||||
def set_size(self, size):
|
||||
if self.size == size:
|
||||
return
|
||||
elif not self.program.options.assemblymode:
|
||||
raise CompilerError('Mismatch of instruction and register size')
|
||||
elif not self.program.program.options.assemblymode:
|
||||
raise CompilerError('Mismatch of instruction and register size:'
|
||||
' %s != %s' % (self.size, size))
|
||||
elif self.size == 1 and self.vectorbase is self:
|
||||
if '%s%d' % (self.reg_type, self.i) in compilerLib.VARS:
|
||||
# create vector register in assembly mode
|
||||
@@ -979,6 +1038,10 @@ class Tape:
|
||||
return Tape.Register(self.reg_type, self.program, size=size, i=i)
|
||||
|
||||
def get_vector(self, base, size):
|
||||
if base == 0 and size == self.size:
|
||||
return self
|
||||
if size == 1:
|
||||
return self[base]
|
||||
res = self._new_by_number(self.i + base, size=size)
|
||||
res.set_vectorbase(self)
|
||||
self.create_vector_elements()
|
||||
@@ -1007,7 +1070,10 @@ class Tape:
|
||||
|
||||
def __len__(self):
|
||||
return self.size
|
||||
|
||||
|
||||
def copy(self):
|
||||
return Tape.Register(self.reg_type, Program.prog.curr_tape)
|
||||
|
||||
@property
|
||||
def value(self):
|
||||
return self.program.reg_values[self.reg_type][self.i]
|
||||
@@ -1029,6 +1095,9 @@ class Tape:
|
||||
self.reg_type == RegType.ClearGF2N or \
|
||||
self.reg_type == RegType.ClearInt
|
||||
|
||||
def __bool__(self):
|
||||
raise CompilerError('cannot derive truth value from register')
|
||||
|
||||
def __str__(self):
|
||||
return self.reg_type + str(self.i)
|
||||
|
||||
|
||||
@@ -95,7 +95,8 @@ class ClientMessageType:
|
||||
|
||||
|
||||
class MPCThread(object):
|
||||
def __init__(self, target, name, args = [], runtime_arg = None):
|
||||
def __init__(self, target, name, args = [], runtime_arg = 0,
|
||||
single_thread = False):
|
||||
""" Create a thread from a callable object. """
|
||||
if not callable(target):
|
||||
raise CompilerError('Target %s for thread %s is not callable' % (target,name))
|
||||
@@ -105,18 +106,33 @@ class MPCThread(object):
|
||||
self.args = args
|
||||
self.runtime_arg = runtime_arg
|
||||
self.running = 0
|
||||
self.tape_handle = program.new_tape(target, args, name,
|
||||
single_thread=single_thread)
|
||||
self.run_handles = []
|
||||
|
||||
def start(self, runtime_arg = None):
|
||||
self.running += 1
|
||||
program.start_thread(self, runtime_arg or self.runtime_arg)
|
||||
self.run_handles.append(program.run_tape(self.tape_handle, \
|
||||
runtime_arg or self.runtime_arg))
|
||||
|
||||
def join(self):
|
||||
if not self.running:
|
||||
raise CompilerError('Thread %s is not running' % self.name)
|
||||
self.running -= 1
|
||||
program.stop_thread(self)
|
||||
program.join_tape(self.run_handles.pop(0))
|
||||
|
||||
|
||||
def copy_doc(a, b):
|
||||
try:
|
||||
a.__doc__ = b.__doc__
|
||||
except:
|
||||
pass
|
||||
|
||||
def no_doc(operation):
|
||||
def wrapper(*args, **kwargs):
|
||||
return operation(*args, **kwargs)
|
||||
return wrapper
|
||||
|
||||
def copy_doc(a, b):
|
||||
try:
|
||||
a.__doc__ = b.__doc__
|
||||
@@ -131,7 +147,9 @@ def no_doc(operation):
|
||||
def vectorize(operation):
|
||||
def vectorized_operation(self, *args, **kwargs):
|
||||
if len(args):
|
||||
from .GC.types import bits
|
||||
if (isinstance(args[0], Tape.Register) or isinstance(args[0], sfloat)) \
|
||||
and not isinstance(args[0], bits) \
|
||||
and args[0].size != self.size:
|
||||
raise CompilerError('Different vector sizes of operands')
|
||||
set_global_vector_size(self.size)
|
||||
@@ -292,6 +310,10 @@ class _int(object):
|
||||
def bit_adder(*args, **kwargs):
|
||||
return intbitint.bit_adder(*args, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def ripple_carry_adder(*args, **kwargs):
|
||||
return intbitint.ripple_carry_adder(*args, **kwargs)
|
||||
|
||||
def if_else(self, a, b):
|
||||
""" MUX on bit in arithmetic circuits.
|
||||
|
||||
@@ -416,6 +438,9 @@ class _structure(object):
|
||||
res_params) \
|
||||
for k in range(len(row))).reduce_after_mul()
|
||||
|
||||
class _vec(object):
|
||||
pass
|
||||
|
||||
class _register(Tape.Register, _number, _structure):
|
||||
@staticmethod
|
||||
def n_elements():
|
||||
@@ -427,7 +452,7 @@ class _register(Tape.Register, _number, _structure):
|
||||
val = val.read()
|
||||
if isinstance(val, cls):
|
||||
return val
|
||||
elif not isinstance(val, _register):
|
||||
elif not isinstance(val, (_register, _vec)):
|
||||
try:
|
||||
return type(val)(cls.conv(v) for v in val)
|
||||
except TypeError:
|
||||
@@ -467,8 +492,7 @@ class _register(Tape.Register, _number, _structure):
|
||||
address = regint.conv(address)
|
||||
if size > 1 and address.size == 1:
|
||||
res = regint(size=size)
|
||||
for i in range(size):
|
||||
movint(res[i], address + regint(i, size=1))
|
||||
incint(res, address, 1)
|
||||
return res
|
||||
else:
|
||||
return address
|
||||
@@ -1089,6 +1113,29 @@ class regint(_register, _int):
|
||||
rand(res, bit_length)
|
||||
return res
|
||||
|
||||
@classmethod
|
||||
def inc(cls, size, base=0, step=1, repeat=1, wrap=None):
|
||||
"""
|
||||
Produce :py:class:`regint` vector with certain patterns.
|
||||
This is particularly useful for :py:meth:`SubMultiArray.direct_mul`.
|
||||
|
||||
:param size: Result size
|
||||
:param base: First value
|
||||
:param step: Increase step
|
||||
:param repeat: Repeate this many times
|
||||
:param wrap: Start over after this many increases
|
||||
|
||||
The following produces (1, 1, 1, 3, 3, 3, 5, 5, 5, 7)::
|
||||
|
||||
regint.inc(10, 1, 2, 3)
|
||||
|
||||
"""
|
||||
res = regint(size=size)
|
||||
if wrap is None:
|
||||
wrap = size
|
||||
incint(res, cls.conv(base, size=1), step, repeat, wrap)
|
||||
return res
|
||||
|
||||
@vectorized_classmethod
|
||||
def read_from_socket(cls, client_id, n=1):
|
||||
""" Receive n register values from socket """
|
||||
@@ -1333,6 +1380,12 @@ class regint(_register, _int):
|
||||
res += bit
|
||||
return res
|
||||
|
||||
def shuffle(self):
|
||||
""" Returns insecure shuffle of vector. """
|
||||
res = regint(size=len(self))
|
||||
shuffle(res, self)
|
||||
return res
|
||||
|
||||
def reveal(self):
|
||||
""" Identity. """
|
||||
return self
|
||||
@@ -1371,7 +1424,7 @@ class localint(object):
|
||||
class _secret(_register):
|
||||
__slots__ = []
|
||||
|
||||
mov = staticmethod(movs)
|
||||
mov = staticmethod(set_instruction_type(movs))
|
||||
PreOR = staticmethod(lambda l: floatingpoint.PreORC(l))
|
||||
PreOp = staticmethod(lambda op, l: floatingpoint.PreOpL(op, l))
|
||||
|
||||
@@ -1475,9 +1528,7 @@ class _secret(_register):
|
||||
res = cls(size=size)
|
||||
n_rows = len(A) // n
|
||||
n_cols = len(B) // n
|
||||
dotprods(*sum(([res[j], [A[j // n_cols * n + k] for k in range(n)],
|
||||
[B[k * n_cols + j % n_cols] for k in range(n)]]
|
||||
for j in range(size)), []))
|
||||
matmuls(res, A, B, n_rows, n, n_cols)
|
||||
return res
|
||||
|
||||
@no_doc
|
||||
@@ -1502,7 +1553,7 @@ class _secret(_register):
|
||||
@read_mem_value
|
||||
@vectorize
|
||||
def load_other(self, val):
|
||||
from Compiler.GC.types import sbits
|
||||
from Compiler.GC.types import sbits, sbitvec
|
||||
if isinstance(val, self.clear_type):
|
||||
self.load_clear(val)
|
||||
elif isinstance(val, type(self)):
|
||||
@@ -1510,9 +1561,16 @@ class _secret(_register):
|
||||
elif isinstance(val, sbits):
|
||||
assert(val.n == self.size)
|
||||
r = self.get_dabit()
|
||||
v = regint()
|
||||
bitdecint_class(regint((r[1] ^ val).reveal()), *v)
|
||||
movs(self, r[0].bit_xor(v))
|
||||
movs(self, r[0].bit_xor((r[1] ^ val).reveal().to_regint_by_bit()))
|
||||
elif isinstance(val, sbitvec):
|
||||
assert(sum(x.n for x in val.v) == self.size)
|
||||
for val_part, base in zip(val, range(0, self.size, 64)):
|
||||
left = min(64, self.size - base)
|
||||
r = self.get_dabit(size=left)
|
||||
v = regint(size=left)
|
||||
bitdecint_class(regint((r[1] ^ val_part).reveal()), *v)
|
||||
part = r[0].bit_xor(v)
|
||||
vmovs(left, self.get_vector(base, left), part)
|
||||
else:
|
||||
self.load_clear(self.clear_type(val))
|
||||
|
||||
@@ -1555,8 +1613,8 @@ class _secret(_register):
|
||||
size or one size 1 for a value-vector multiplication.
|
||||
|
||||
:param other: any compatible type """
|
||||
if isinstance(other, _secret) and max(self.size, other.size) > 1 \
|
||||
and min(self.size, other.size) == 1:
|
||||
if isinstance(other, _secret) and (1 in (self.size, other.size)) \
|
||||
and (self.size, other.size) != (1, 1):
|
||||
x, y = (other, self) if self.size < other.size else (self, other)
|
||||
res = type(self)(size=x.size)
|
||||
mulrs(res, x, y)
|
||||
@@ -1667,10 +1725,35 @@ class sint(_secret, _int):
|
||||
dabit(*res)
|
||||
return res
|
||||
|
||||
@vectorized_classmethod
|
||||
def get_edabit(cls, n_bits, strict=False):
|
||||
""" Bits in arithmetic and binary circuit """
|
||||
""" according to security model """
|
||||
if not program.use_edabit_for(strict, n_bits):
|
||||
if program.use_dabit:
|
||||
a, b = zip(*(sint.get_dabit() for i in range(n_bits)))
|
||||
return sint.bit_compose(a), b
|
||||
else:
|
||||
a = [sint.get_random_bit() for i in range(n_bits)]
|
||||
return sint.bit_compose(a), a
|
||||
whole = cls()
|
||||
size = get_global_vector_size()
|
||||
from Compiler.GC.types import sbits, sbitvec
|
||||
bits = [sbits.get_type(size)() for i in range(n_bits)]
|
||||
if strict:
|
||||
sedabit(whole, *bits)
|
||||
else:
|
||||
edabit(whole, *bits)
|
||||
return whole, bits
|
||||
|
||||
@staticmethod
|
||||
def long_one():
|
||||
return 1
|
||||
|
||||
@staticmethod
|
||||
def bit_decompose_clear(a, n_bits):
|
||||
return floatingpoint.bits(a, n_bits)
|
||||
|
||||
@classmethod
|
||||
def get_raw_input_from(cls, player):
|
||||
res = cls()
|
||||
@@ -1733,6 +1816,15 @@ class sint(_secret, _int):
|
||||
""" Store in memory by public address. """
|
||||
self._store_in_mem(address, stms, stmsi)
|
||||
|
||||
@classmethod
|
||||
def direct_matrix_mul(cls, A, B, n, m, l, reduce=None, indices=None):
|
||||
if indices is None:
|
||||
indices = [regint.inc(i) for i in (n, m, m, l)]
|
||||
res = cls(size=indices[0].size * indices[3].size)
|
||||
matmulsm(res, regint(A), regint(B), len(indices[0]), len(indices[1]),
|
||||
len(indices[3]), *(list(indices) + [m, l]))
|
||||
return res
|
||||
|
||||
def __init__(self, val=None, size=None):
|
||||
"""
|
||||
:param val: initialization (sint/cint/regint/int/cgf2n or list thereof)
|
||||
@@ -1810,6 +1902,7 @@ class sint(_secret, _int):
|
||||
return self.mod2m(int(l))
|
||||
raise NotImplementedError('Modulo only implemented for powers of two.')
|
||||
|
||||
@vectorize
|
||||
@read_mem_value
|
||||
def mod2m(self, m, bit_length=None, security=None, signed=True):
|
||||
""" Secret modulo power of two.
|
||||
@@ -1941,6 +2034,10 @@ class sint(_secret, _int):
|
||||
comparison.Trunc(res, tmp, 2 * k, k, kappa, True)
|
||||
return res
|
||||
|
||||
def trunc_zeros(self, n_zeros, bit_length=None, signed=True):
|
||||
bit_length = bit_length or program.bit_length
|
||||
return comparison.TruncZeros(self, bit_length, n_zeros, signed)
|
||||
|
||||
@staticmethod
|
||||
def two_power(n):
|
||||
return floatingpoint.two_power(n)
|
||||
@@ -2120,11 +2217,14 @@ class _bitint(object):
|
||||
@classmethod
|
||||
def bit_adder_selection(cls, a, b, carry_in=0, get_carry=False):
|
||||
if cls.log_rounds:
|
||||
return cls.carry_lookahead_adder(a, b, carry_in=carry_in)
|
||||
return cls.carry_lookahead_adder(a, b, carry_in=carry_in,
|
||||
get_carry=get_carry)
|
||||
elif cls.linear_rounds:
|
||||
return cls.ripple_carry_adder(a, b, carry_in=carry_in)
|
||||
return cls.ripple_carry_adder(a, b, carry_in=carry_in,
|
||||
get_carry=get_carry)
|
||||
else:
|
||||
return cls.carry_select_adder(a, b, carry_in=carry_in)
|
||||
return cls.carry_select_adder(a, b, carry_in=carry_in,
|
||||
get_carry=get_carry)
|
||||
|
||||
@classmethod
|
||||
def carry_lookahead_adder(cls, a, b, fewer_inv=False, carry_in=0,
|
||||
@@ -2205,8 +2305,8 @@ class _bitint(object):
|
||||
|
||||
@staticmethod
|
||||
def full_adder(a, b, carry):
|
||||
s = a + b
|
||||
return s + carry, util.if_else(s, carry, a)
|
||||
s = a ^ b
|
||||
return s ^ carry, a ^ (s & (carry ^ a))
|
||||
|
||||
@staticmethod
|
||||
def bit_comparator(a, b):
|
||||
@@ -2243,6 +2343,7 @@ class _bitint(object):
|
||||
b = util.bit_decompose(other, self.n_bits)
|
||||
return self.compose(self.bit_adder(a, b))
|
||||
|
||||
@ret_cisc
|
||||
def mul(self, other):
|
||||
if type(other) == self.bin_type:
|
||||
raise CompilerError('Unclear multiplication')
|
||||
@@ -2270,12 +2371,13 @@ class _bitint(object):
|
||||
@classmethod
|
||||
def wallace_tree_from_matrix(cls, bit_matrix, get_carry=True):
|
||||
columns = [[_f for _f in (bit_matrix[j][i-j] \
|
||||
for j in range(min(len(bit_matrix), i + 1))) if _f] \
|
||||
for j in range(min(len(bit_matrix), i + 1))) \
|
||||
if not is_zero(_f)] \
|
||||
for i in range(len(bit_matrix[0]))]
|
||||
return cls.wallace_tree_from_columns(columns, get_carry)
|
||||
|
||||
@classmethod
|
||||
def wallace_tree_from_columns(cls, columns, get_carry=True):
|
||||
def wallace_tree_without_finish(cls, columns, get_carry=True):
|
||||
self = cls
|
||||
while max(len(c) for c in columns) > 2:
|
||||
new_columns = [[] for i in range(len(columns) + 1)]
|
||||
@@ -2296,7 +2398,12 @@ class _bitint(object):
|
||||
columns = new_columns[:-1]
|
||||
for col in columns:
|
||||
col.extend([0] * (2 - len(col)))
|
||||
return self.bit_adder(*list(zip(*columns)))
|
||||
return tuple(list(x) for x in zip(*columns))
|
||||
|
||||
@classmethod
|
||||
def wallace_tree_from_columns(cls, columns, get_carry=True):
|
||||
summands = cls.wallace_tree_without_finish(columns, get_carry)
|
||||
return cls.bit_adder(*summands)
|
||||
|
||||
@classmethod
|
||||
def wallace_tree(cls, rows):
|
||||
@@ -2556,15 +2663,15 @@ def parse_type(other, k=None, f=None):
|
||||
if isinstance(other, cfix.scalars):
|
||||
return cfix(other, k=k, f=f)
|
||||
elif isinstance(other, cint):
|
||||
tmp = cfix()
|
||||
tmp = cfix(k=k, f=f)
|
||||
tmp.load_int(other)
|
||||
return tmp
|
||||
elif isinstance(other, sint):
|
||||
tmp = sfix()
|
||||
tmp = sfix(k=k, f=f)
|
||||
tmp.load_int(other)
|
||||
return tmp
|
||||
elif isinstance(other, sfloat):
|
||||
tmp = sfix(other)
|
||||
tmp = sfix(other, k=k, f=f)
|
||||
return tmp
|
||||
else:
|
||||
return other
|
||||
@@ -2631,8 +2738,8 @@ class cfix(_number, _structure):
|
||||
@vectorize_init
|
||||
def __init__(self, v=None, k=None, f=None, size=None):
|
||||
""" :param v: cfix/float/int """
|
||||
f = f or self.f
|
||||
k = k or self.k
|
||||
f = self.f if f is None else f
|
||||
k = self.k if k is None else k
|
||||
self.f = f
|
||||
self.k = k
|
||||
self.size = get_global_vector_size()
|
||||
@@ -2693,7 +2800,9 @@ class cfix(_number, _structure):
|
||||
def mul(self, other):
|
||||
""" Clear fixed-point multiplication.
|
||||
|
||||
:param other: cfix/cint/regint/int """
|
||||
:param other: cfix/cint/regint/int/sint """
|
||||
if isinstance(other, sint):
|
||||
return sfix._new(self.v * other, k=self.k, f=self.f)
|
||||
other = parse_type(other)
|
||||
if isinstance(other, cfix):
|
||||
assert self.f == other.f
|
||||
@@ -2814,8 +2923,12 @@ class cfix(_number, _structure):
|
||||
if isinstance(other, cfix):
|
||||
return cfix(library.cint_cint_division(self.v, other.v, self.k, self.f))
|
||||
elif isinstance(other, sfix):
|
||||
return sfix(library.FPDiv(self.v, other.v, self.k, self.f,
|
||||
other.kappa, nearest=sfix.round_nearest))
|
||||
assert self.k == other.k
|
||||
assert self.f == other.f
|
||||
return sfix._new(library.FPDiv(self.v, other.v, self.k, self.f,
|
||||
other.kappa,
|
||||
nearest=sfix.round_nearest),
|
||||
k=self.k, f=self.f)
|
||||
else:
|
||||
raise TypeError('Incompatible fixed point types in division')
|
||||
|
||||
@@ -3022,25 +3135,29 @@ class _fix(_single):
|
||||
""" Convert secret integer.
|
||||
|
||||
:param other: sint """
|
||||
res = cls()
|
||||
res.f = f or cls.f
|
||||
res.k = k or cls.k
|
||||
res = cls(k=k, f=f)
|
||||
res.load_int(cls.int_type.conv(other))
|
||||
return res
|
||||
|
||||
@classmethod
|
||||
def _new(cls, other, k=None, f=None):
|
||||
res = cls(other)
|
||||
res.k = k or cls.k
|
||||
res.f = f or cls.f
|
||||
res = cls(other, k=k, f=f)
|
||||
return res
|
||||
|
||||
@vectorize_init
|
||||
def __init__(self, _v=None, size=None):
|
||||
def __init__(self, _v=None, k=None, f=None, size=None):
|
||||
""" :params _v: compile-time value (int/float) """
|
||||
self.size = get_global_vector_size()
|
||||
f = self.f
|
||||
k = self.k
|
||||
if k is None:
|
||||
k = self.k
|
||||
else:
|
||||
self.k = k
|
||||
if f is None:
|
||||
f = self.f
|
||||
else:
|
||||
self.f = f
|
||||
assert k is not None
|
||||
assert f is not None
|
||||
# warning: don't initialize a sfix from a sint, this is only used in internal methods;
|
||||
# for external initialization use load_int.
|
||||
if _v is None:
|
||||
@@ -3110,7 +3227,7 @@ class _fix(_single):
|
||||
val = self.v.TruncMul(other.v, self.k + other.k, other.f,
|
||||
self.kappa,
|
||||
self.round_nearest)
|
||||
if self.size >= other.size:
|
||||
if 'vec' not in self.__dict__:
|
||||
return self._new(val, k=self.k, f=self.f)
|
||||
else:
|
||||
return self.vec._new(val, k=self.k, f=self.f)
|
||||
@@ -3123,7 +3240,7 @@ class _fix(_single):
|
||||
@vectorize
|
||||
def __neg__(self):
|
||||
""" Secret fixed-point negation. """
|
||||
return type(self)(-self.v)
|
||||
return self._new(-self.v, k=self.k, f=self.f)
|
||||
|
||||
@vectorize
|
||||
def __truediv__(self, other):
|
||||
@@ -3131,14 +3248,17 @@ class _fix(_single):
|
||||
|
||||
:param other: sfix/cfix/sint/cint/regint/int """
|
||||
other = self.coerce(other)
|
||||
assert self.k == other.k
|
||||
assert self.f == other.f
|
||||
if isinstance(other, _fix):
|
||||
return type(self)(library.FPDiv(self.v, other.v, self.k, self.f,
|
||||
self.kappa,
|
||||
nearest=self.round_nearest))
|
||||
v = library.FPDiv(self.v, other.v, self.k, self.f, self.kappa,
|
||||
nearest=self.round_nearest)
|
||||
elif isinstance(other, cfix):
|
||||
return type(self)(library.sint_cint_division(self.v, other.v, self.k, self.f, self.kappa))
|
||||
v = library.sint_cint_division(self.v, other.v, self.k, self.f,
|
||||
self.kappa)
|
||||
else:
|
||||
raise TypeError('Incompatible fixed point types in division')
|
||||
return self._new(v, k=self.k, f=self.f)
|
||||
|
||||
@vectorize
|
||||
def __rtruediv__(self, other):
|
||||
@@ -3192,11 +3312,21 @@ class sfix(_fix):
|
||||
lower = average - 0.5 * 2 ** log_range
|
||||
return cls._new(cls.int_type.get_random_int(n_bits)) + lower
|
||||
|
||||
@classmethod
|
||||
def direct_matrix_mul(cls, A, B, n, m, l, reduce=True, indices=None):
|
||||
# pre-multiplication must be identity
|
||||
tmp = cls.int_type.direct_matrix_mul(A, B, n, m, l, indices=indices)
|
||||
res = unreduced_sfix._new(tmp)
|
||||
if reduce:
|
||||
res = res.reduce_after_mul()
|
||||
return res
|
||||
|
||||
def coerce(self, other):
|
||||
return parse_type(other, k=self.k, f=self.f)
|
||||
|
||||
def mul_no_reduce(self, other, res_params=None):
|
||||
assert self.f == other.f
|
||||
assert self.k == other.k
|
||||
return self.unreduced(self.v * other.v)
|
||||
|
||||
def pre_mul(self):
|
||||
@@ -3221,6 +3351,8 @@ class unreduced_sfix(_single):
|
||||
self.k = k
|
||||
self.m = m
|
||||
self.kappa = kappa
|
||||
assert self.k is not None
|
||||
assert self.m is not None
|
||||
|
||||
def __add__(self, other):
|
||||
if is_zero(other):
|
||||
@@ -3236,7 +3368,8 @@ class unreduced_sfix(_single):
|
||||
def reduce_after_mul(self):
|
||||
return sfix(sfix.int_type.round(self.v, self.k, self.m, self.kappa,
|
||||
nearest=sfix.round_nearest,
|
||||
signed=True))
|
||||
signed=True),
|
||||
k=self.k // 2, f=self.m)
|
||||
|
||||
sfix.unreduced_type = unreduced_sfix
|
||||
|
||||
@@ -4012,12 +4145,15 @@ class Array(object):
|
||||
pass
|
||||
try:
|
||||
other.store_in_mem(self.get_address(base))
|
||||
assert len(self) >= other.size + base
|
||||
if len(self) != None and util.is_constant(base):
|
||||
assert len(self) >= other.size + base
|
||||
except AttributeError:
|
||||
for i,j in enumerate(other):
|
||||
self[i] = j
|
||||
return self
|
||||
|
||||
assign_vector = assign
|
||||
|
||||
def assign_all(self, value, use_threads=True, conv=True):
|
||||
""" Assign the same value to all entries.
|
||||
|
||||
@@ -4040,6 +4176,19 @@ class Array(object):
|
||||
size = size or self.length
|
||||
return self.value_type.load_mem(self.get_address(base), size=size)
|
||||
|
||||
get_part_vector = get_vector
|
||||
|
||||
def get(self, indices):
|
||||
return self.value_type.load_mem(
|
||||
regint(self.address, size=len(indices)) + indices,
|
||||
size=len(indices))
|
||||
|
||||
def expand_to_vector(self, index, size):
|
||||
assert self.value_type.n_elements() == 1
|
||||
address = regint(size=size)
|
||||
incint(address, regint(self.get_address(index), size=1), 0)
|
||||
return self.value_type.load_mem(address, size=size)
|
||||
|
||||
def get_mem_value(self, index):
|
||||
return MemValue(self[index], self.get_address(index))
|
||||
|
||||
@@ -4082,12 +4231,15 @@ class Array(object):
|
||||
|
||||
def shuffle(self):
|
||||
""" Insecure shuffle in place. """
|
||||
@library.for_range(len(self))
|
||||
def _(i):
|
||||
j = regint.get_random(64) % (len(self) - i)
|
||||
tmp = self[i]
|
||||
self[i] = self[i + j]
|
||||
self[i + j] = tmp
|
||||
if self.value_type == regint:
|
||||
self.assign(self.get_vector().shuffle())
|
||||
else:
|
||||
@library.for_range(len(self))
|
||||
def _(i):
|
||||
j = regint.get_random(64) % (len(self) - i)
|
||||
tmp = self[i]
|
||||
self[i] = self[i + j]
|
||||
self[i + j] = tmp
|
||||
|
||||
def reveal(self):
|
||||
""" Reveal the whole array.
|
||||
@@ -4184,6 +4336,14 @@ class SubMultiArray(object):
|
||||
assert self.sizes == other.sizes
|
||||
self.assign_vector(other.get_vector())
|
||||
|
||||
def get_part_vector(self, base=0, size=None):
|
||||
assert self.value_type.n_elements() == 1
|
||||
part_size = reduce(operator.mul, self.sizes[1:])
|
||||
size = (size or len(self)) * part_size
|
||||
assert size <= self.total_size()
|
||||
return self.value_type.load_mem(self.address + base * part_size,
|
||||
size=size)
|
||||
|
||||
def same_shape(self):
|
||||
""" :return: new multidimensional array with same shape and basic type """
|
||||
return MultiArray(self.sizes, self.value_type)
|
||||
@@ -4254,6 +4414,7 @@ class SubMultiArray(object):
|
||||
for i, x in enumerate(other):
|
||||
matrix[i][0] = x
|
||||
res = self * matrix
|
||||
library.break_point()
|
||||
return Array.create_from(x[0] for x in res)
|
||||
elif isinstance(other, SubMultiArray):
|
||||
assert len(other.sizes) == 2
|
||||
@@ -4292,6 +4453,32 @@ class SubMultiArray(object):
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
def direct_mul(self, other, reduce=True, indices=None):
|
||||
""" Matrix multiplication in the virtual machine.
|
||||
|
||||
:param self: :py:class:`Matrix` / 2-dimensional :py:class:`MultiArray`
|
||||
:param other: :py:class:`Matrix` / 2-dimensional :py:class:`MultiArray`
|
||||
:param indices: 4-tuple of :py:class:`regint` vectors for index selection (default is complete multiplication)
|
||||
:return: Matrix as vector of relevant type (row-major)
|
||||
|
||||
The following executes a matrix multiplication selecting every third row
|
||||
of :py:obj:`A`::
|
||||
|
||||
A = sfix.Matrix(7, 4)
|
||||
B = sfix.Matrix(4, 5)
|
||||
C = sfix.Matrix(3, 5)
|
||||
C.assign_vector(A.direct_mul(B, indices=(regint.inc(3, 0, 3),
|
||||
regint.inc(4),
|
||||
regint.inc(4),
|
||||
regint.inc(5)))
|
||||
"""
|
||||
assert len(self.sizes) == 2
|
||||
assert len(other.sizes) == 2
|
||||
assert self.sizes[1] == other.sizes[0]
|
||||
return self.value_type.direct_matrix_mul(self.address, other.address,
|
||||
self.sizes[0], *other.sizes,
|
||||
reduce=reduce, indices=indices)
|
||||
|
||||
def budget_mul(self, other, n_rows, row, n_columns, column, reduce=True,
|
||||
res=None):
|
||||
assert len(self.sizes) == 2
|
||||
@@ -4349,6 +4536,26 @@ class SubMultiArray(object):
|
||||
lambda x, j: [x[k][j] for k in range(len(x))],
|
||||
reduce=reduce, res=res)
|
||||
|
||||
def parallel_mul(self, other):
|
||||
assert self.sizes[1] == other.sizes[0]
|
||||
assert len(self.sizes) == 2
|
||||
assert len(other.sizes) == 2
|
||||
assert self.value_type.n_elements() == 1
|
||||
n = self.sizes[0] * other.sizes[1]
|
||||
a = []
|
||||
b = []
|
||||
for i in range(self.sizes[1]):
|
||||
addresses = regint(size=n)
|
||||
incint(addresses, regint(self.address + i), self.sizes[1],
|
||||
other.sizes[1], n)
|
||||
a.append(self.value_type.load_mem(addresses, size=n))
|
||||
addresses = regint(size=n)
|
||||
incint(addresses, regint(other.address + i * other.sizes[1]), 1,
|
||||
1, other.sizes[1])
|
||||
b.append(self.value_type.load_mem(addresses, size=n))
|
||||
res = self.value_type.dot_product(a, b)
|
||||
return res
|
||||
|
||||
def transpose(self):
|
||||
""" Matrix transpose.
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ public:
|
||||
bool fewer_rounds;
|
||||
bool check_open;
|
||||
bool check_beaver_open;
|
||||
bool R_after_msg;
|
||||
|
||||
EcdsaOptions(ez::ezOptionParser& opt, int argc, const char** argv)
|
||||
{
|
||||
@@ -54,11 +55,21 @@ public:
|
||||
"-B", // Flag token.
|
||||
"--no-beaver-open-check" // Flag token.
|
||||
);
|
||||
opt.add(
|
||||
"", // Default.
|
||||
0, // Required?
|
||||
0, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
"Only open R after message is known", // Help description.
|
||||
"-R", // Flag token.
|
||||
"--R-after-msg" // Flag token.
|
||||
);
|
||||
opt.parse(argc, argv);
|
||||
prep_mul = not opt.isSet("-D");
|
||||
fewer_rounds = opt.isSet("-P");
|
||||
check_open = not opt.isSet("-C");
|
||||
check_beaver_open = not opt.isSet("-B");
|
||||
R_after_msg = opt.isSet("-R");
|
||||
opt.resetArgs();
|
||||
}
|
||||
};
|
||||
|
||||
@@ -17,6 +17,10 @@
|
||||
#include "Processor/Input.hpp"
|
||||
#include "Processor/Processor.hpp"
|
||||
#include "Processor/Data_Files.hpp"
|
||||
#include "Protocols/MascotPrep.hpp"
|
||||
#include "GC/Secret.hpp"
|
||||
#include "GC/TinyPrep.hpp"
|
||||
#include "OT/NPartyTripleGenerator.hpp"
|
||||
|
||||
#include <assert.h>
|
||||
|
||||
@@ -41,6 +45,10 @@ int main(int argc, const char** argv)
|
||||
Sub_Data_Files<pShare> prep(N, prefix, usage);
|
||||
typename pShare::Direct_MC MCp(keyp);
|
||||
ArithmeticProcessor _({}, 0);
|
||||
BaseMachine machine;
|
||||
machine.ot_setups.push_back({P, false});
|
||||
GC::ShareThread<typename pShare::bit_type> thread(N,
|
||||
OnlineOptions::singleton, P, {}, usage);
|
||||
SubProcessor<pShare> proc(_, MCp, prep, P);
|
||||
|
||||
pShare sk, __;
|
||||
|
||||
@@ -11,11 +11,14 @@
|
||||
#include "Math/gfp.h"
|
||||
#include "ECDSA/P256Element.h"
|
||||
#include "Tools/Bundle.h"
|
||||
#include "GC/TinyMC.h"
|
||||
#include "GC/MaliciousCcdSecret.h"
|
||||
#include "GC/CcdSecret.h"
|
||||
#include "GC/VectorInput.h"
|
||||
|
||||
#include "ECDSA/preprocessing.hpp"
|
||||
#include "ECDSA/sign.hpp"
|
||||
#include "Protocols/MaliciousRepMC.hpp"
|
||||
#include "Protocols/MaliciousRepPrep.hpp"
|
||||
#include "Protocols/Beaver.hpp"
|
||||
#include "Protocols/fake-stuff.hpp"
|
||||
#include "Processor/Input.hpp"
|
||||
@@ -24,6 +27,8 @@
|
||||
#include "GC/ShareSecret.hpp"
|
||||
#include "GC/RepPrep.hpp"
|
||||
#include "GC/ThreadMaster.hpp"
|
||||
#include "GC/Secret.hpp"
|
||||
#include "Machines/ShamirMachine.hpp"
|
||||
|
||||
#include <assert.h>
|
||||
|
||||
|
||||
@@ -7,8 +7,6 @@
|
||||
|
||||
#include "Protocols/Shamir.hpp"
|
||||
#include "Protocols/ShamirInput.hpp"
|
||||
#include "Protocols/ShamirMC.hpp"
|
||||
#include "Protocols/MaliciousShamirMC.hpp"
|
||||
|
||||
#include "hm-ecdsa-party.hpp"
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@
|
||||
#include "Processor/Processor.hpp"
|
||||
#include "Processor/Data_Files.hpp"
|
||||
#include "Processor/Input.hpp"
|
||||
#include "GC/TinyPrep.hpp"
|
||||
|
||||
#include <assert.h>
|
||||
|
||||
@@ -103,6 +104,8 @@ void run(int argc, const char** argv)
|
||||
typename pShare::Direct_MC MCp(keyp, N, 0);
|
||||
ArithmeticProcessor _({}, 0);
|
||||
typename pShare::LivePrep sk_prep(0, usage);
|
||||
GC::ShareThread<typename pShare::bit_type> thread(N,
|
||||
OnlineOptions::singleton, P, {}, usage);
|
||||
SubProcessor<pShare> sk_proc(_, MCp, sk_prep, P);
|
||||
pShare sk, __;
|
||||
// synchronize
|
||||
|
||||
@@ -11,6 +11,13 @@
|
||||
#include "Processor/Data_Files.h"
|
||||
#include "Protocols/ReplicatedPrep.h"
|
||||
#include "Protocols/MaliciousShamirShare.h"
|
||||
#include "GC/TinierSecret.h"
|
||||
#include "GC/TinierPrep.h"
|
||||
#include "GC/MaliciousCcdSecret.h"
|
||||
#include "GC/TinyMC.h"
|
||||
|
||||
#include "GC/TinierSharePrep.hpp"
|
||||
#include "GC/CcdSecret.h"
|
||||
|
||||
template<template<class U> class T>
|
||||
class EcTuple
|
||||
@@ -18,6 +25,8 @@ class EcTuple
|
||||
public:
|
||||
T<P256Element::Scalar> a;
|
||||
T<P256Element::Scalar> b;
|
||||
P256Element::Scalar c;
|
||||
T<P256Element> secret_R;
|
||||
P256Element R;
|
||||
};
|
||||
|
||||
@@ -62,9 +71,10 @@ void preprocessing(vector<EcTuple<T>>& tuples, int buffer_size,
|
||||
for (int i = 0; i < buffer_size; i++)
|
||||
secret_Rs.push_back(bs[i] / cs_opened[i]);
|
||||
}
|
||||
vector<P256Element> opened_Rs;
|
||||
vector<P256Element> opened_Rs(buffer_size);
|
||||
typename cShare::Direct_MC MCc(MCp.get_alphai());
|
||||
MCc.POpen_Begin(opened_Rs, secret_Rs, extra_player);
|
||||
if (not opts.R_after_msg)
|
||||
MCc.POpen_Begin(opened_Rs, secret_Rs, extra_player);
|
||||
if (prep_mul)
|
||||
{
|
||||
protocol.init_mul(&proc);
|
||||
@@ -74,10 +84,13 @@ void preprocessing(vector<EcTuple<T>>& tuples, int buffer_size,
|
||||
}
|
||||
if (opts.fewer_rounds)
|
||||
MCp.POpen_End(cs_opened, cs, extra_player);
|
||||
MCc.POpen_End(opened_Rs, secret_Rs, extra_player);
|
||||
if (opts.fewer_rounds)
|
||||
for (int i = 0; i < buffer_size; i++)
|
||||
opened_Rs[i] /= cs_opened[i];
|
||||
if (not opts.R_after_msg)
|
||||
{
|
||||
MCc.POpen_End(opened_Rs, secret_Rs, extra_player);
|
||||
if (opts.fewer_rounds)
|
||||
for (int i = 0; i < buffer_size; i++)
|
||||
opened_Rs[i] /= cs_opened[i];
|
||||
}
|
||||
if (prep_mul)
|
||||
protocol.stop_exchange();
|
||||
if (opts.check_open)
|
||||
@@ -88,7 +101,7 @@ void preprocessing(vector<EcTuple<T>>& tuples, int buffer_size,
|
||||
{
|
||||
tuples.push_back(
|
||||
{ inv_ks[i], prep_mul ? protocol.finalize_mul() : pShare(),
|
||||
opened_Rs[i] });
|
||||
cs_opened[i], secret_Rs[i], opened_Rs[i] });
|
||||
}
|
||||
timer.stop();
|
||||
cout << "Generated " << buffer_size << " tuples in " << timer.elapsed()
|
||||
@@ -112,6 +125,7 @@ void check(vector<EcTuple<T>>& tuples, T<P256Element::Scalar> sk,
|
||||
assert(open_sk * inv_k == MC.POpen(tuple.b, P));
|
||||
assert(tuple.R == k);
|
||||
}
|
||||
MC.Check(P);
|
||||
}
|
||||
|
||||
template<>
|
||||
|
||||
@@ -7,7 +7,6 @@
|
||||
|
||||
#include "Protocols/Shamir.hpp"
|
||||
#include "Protocols/ShamirInput.hpp"
|
||||
#include "Protocols/ShamirMC.hpp"
|
||||
|
||||
#include "hm-ecdsa-party.hpp"
|
||||
|
||||
|
||||
@@ -49,7 +49,10 @@ inline P256Element::Scalar hash_to_scalar(const unsigned char* message, size_t l
|
||||
template<template<class U> class T>
|
||||
EcSignature sign(const unsigned char* message, size_t length,
|
||||
EcTuple<T> tuple,
|
||||
typename T<P256Element::Scalar>::MAC_Check& MC, Player& P,
|
||||
typename T<P256Element::Scalar>::MAC_Check& MC,
|
||||
typename T<P256Element>::MAC_Check& MCc,
|
||||
Player& P,
|
||||
EcdsaOptions opts,
|
||||
P256Element pk,
|
||||
T<P256Element::Scalar> sk = {},
|
||||
SubProcessor<T<P256Element::Scalar>>* proc = 0)
|
||||
@@ -60,16 +63,30 @@ EcSignature sign(const unsigned char* message, size_t length,
|
||||
size_t start = P.sent;
|
||||
auto stats = P.comm_stats;
|
||||
EcSignature signature;
|
||||
signature.R = tuple.R;
|
||||
vector<P256Element> opened_R;
|
||||
if (opts.R_after_msg)
|
||||
MCc.POpen_Begin(opened_R, {tuple.secret_R}, P);
|
||||
T<P256Element::Scalar> prod = tuple.b;
|
||||
auto& protocol = proc->protocol;
|
||||
if (proc)
|
||||
{
|
||||
auto& protocol = proc->protocol;
|
||||
protocol.init_mul(proc);
|
||||
protocol.prepare_mul(sk, tuple.a);
|
||||
protocol.exchange();
|
||||
protocol.start_exchange();
|
||||
}
|
||||
if (opts.R_after_msg)
|
||||
{
|
||||
MCc.POpen_End(opened_R, {tuple.secret_R}, P);
|
||||
tuple.R = opened_R[0];
|
||||
if (opts.fewer_rounds)
|
||||
tuple.R /= tuple.c;
|
||||
}
|
||||
if (proc)
|
||||
{
|
||||
protocol.stop_exchange();
|
||||
prod = protocol.finalize_mul();
|
||||
}
|
||||
signature.R = tuple.R;
|
||||
auto rx = tuple.R.x();
|
||||
signature.s = MC.open(
|
||||
tuple.a * hash_to_scalar(message, length) + prod * rx, P);
|
||||
@@ -132,7 +149,7 @@ void sign_benchmark(vector<EcTuple<T>>& tuples, T<P256Element::Scalar> sk,
|
||||
|
||||
for (size_t i = 0; i < min(10lu, tuples.size()); i++)
|
||||
{
|
||||
check(sign(message, 1 << i, tuples[i], MCp, P, pk, sk, proc), message,
|
||||
check(sign(message, 1 << i, tuples[i], MCp, MCc, P, opts, pk, sk, proc), message,
|
||||
1 << i, pk);
|
||||
if (not opts.check_open)
|
||||
continue;
|
||||
@@ -142,6 +159,7 @@ void sign_benchmark(vector<EcTuple<T>>& tuples, T<P256Element::Scalar> sk,
|
||||
auto stats = check_player.comm_stats;
|
||||
auto start = check_player.sent;
|
||||
MCp.Check(P);
|
||||
MCc.Check(P);
|
||||
cout << "Online checking took " << timer.elapsed() * 1e3 << " ms and sending "
|
||||
<< (check_player.sent - start) << " bytes" << endl;
|
||||
auto diff = (check_player.comm_stats - stats);
|
||||
|
||||
@@ -221,4 +221,12 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
class ran_out
|
||||
{
|
||||
const char* what() const
|
||||
{
|
||||
return "insufficient preprocessing";
|
||||
}
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
45
FHE/AddableVector.cpp
Normal file
45
FHE/AddableVector.cpp
Normal file
@@ -0,0 +1,45 @@
|
||||
/*
|
||||
* AddableVector.cpp
|
||||
*
|
||||
*/
|
||||
|
||||
#include "AddableVector.h"
|
||||
#include "Rq_Element.h"
|
||||
#include "FHE_Keys.h"
|
||||
|
||||
template<class T>
|
||||
AddableVector<T> AddableVector<T>::mul_by_X_i(int j,
|
||||
const FHE_PK& pk) const
|
||||
{
|
||||
int phi_m = this->size();
|
||||
assert(phi_m == pk.get_params().phi_m());
|
||||
AddableVector res(phi_m);
|
||||
for (int i = 0; i < phi_m; i++)
|
||||
{
|
||||
int k = j + i, s = 1;
|
||||
while (k >= phi_m)
|
||||
{
|
||||
k -= phi_m;
|
||||
s = -s;
|
||||
}
|
||||
if (s == 1)
|
||||
{
|
||||
res[k] = (*this)[i];
|
||||
}
|
||||
else
|
||||
{
|
||||
res[k] = -(*this)[i];
|
||||
}
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
template
|
||||
AddableVector<fixint<0>> AddableVector<fixint<0>>::mul_by_X_i(int j,
|
||||
const FHE_PK& pk) const;
|
||||
template
|
||||
AddableVector<fixint<1>> AddableVector<fixint<1>>::mul_by_X_i(int j,
|
||||
const FHE_PK& pk) const;
|
||||
template
|
||||
AddableVector<fixint<2>> AddableVector<fixint<2>>::mul_by_X_i(int j,
|
||||
const FHE_PK& pk) const;
|
||||
@@ -10,6 +10,7 @@
|
||||
using namespace std;
|
||||
|
||||
#include "FHE/Plaintext.h"
|
||||
#include "Rq_Element.h"
|
||||
|
||||
template<class T>
|
||||
class AddableVector: public vector<T>
|
||||
@@ -27,6 +28,11 @@ public:
|
||||
this->assign(other.begin(), other.end());
|
||||
}
|
||||
|
||||
AddableVector(const Rq_Element& other) :
|
||||
AddableVector(other.to_vec_bigint())
|
||||
{
|
||||
}
|
||||
|
||||
template <class U>
|
||||
void allocate_slots(const U& init)
|
||||
{
|
||||
@@ -66,6 +72,8 @@ public:
|
||||
(*this)[i].mul(x[i], y[i]);
|
||||
}
|
||||
|
||||
AddableVector mul_by_X_i(int i, const FHE_PK& pk) const;
|
||||
|
||||
void generateUniform(PRNG& G, int n_bits)
|
||||
{
|
||||
for (unsigned int i = 0; i < this->size(); i++)
|
||||
@@ -171,6 +179,19 @@ public:
|
||||
for (int i = 0; i < n; i++)
|
||||
(*this)[i].resize(m);
|
||||
}
|
||||
|
||||
AddableMatrix mul_by_X_i(int i, const FHE_PK& pk) const;
|
||||
};
|
||||
|
||||
template<class T>
|
||||
AddableMatrix<T> AddableMatrix<T>::mul_by_X_i(int i,
|
||||
const FHE_PK& pk) const
|
||||
{
|
||||
AddableMatrix<T> res;
|
||||
res.resize(this->size());
|
||||
for (size_t j = 0; j < this->size(); j++)
|
||||
res[j] = (*this)[j].mul_by_X_i(i, pk);
|
||||
return res;
|
||||
}
|
||||
|
||||
#endif /* FHEOFFLINE_ADDABLEVECTOR_H_ */
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
#include "Ciphertext.h"
|
||||
#include "PPData.h"
|
||||
#include "P2Data.h"
|
||||
#include "Exceptions/Exceptions.h"
|
||||
|
||||
Ciphertext::Ciphertext(const FHE_PK& pk) : Ciphertext(pk.get_params())
|
||||
@@ -118,12 +119,11 @@ template<class T,class FD,class S>
|
||||
void mul(Ciphertext& ans,const Plaintext<T,FD,S>& a,const Ciphertext& c)
|
||||
{
|
||||
a.to_poly();
|
||||
const vector<S>& aa=a.get_poly();
|
||||
|
||||
int lev=c.cc0.level();
|
||||
Rq_Element ra((*ans.params).FFTD(),evaluation,evaluation);
|
||||
if (lev==0) { ra.lower_level(); }
|
||||
ra.from_vec(aa);
|
||||
ra.from(a.get_iterator());
|
||||
ans.mul(c, ra);
|
||||
}
|
||||
|
||||
|
||||
@@ -35,11 +35,17 @@ class Ciphertext
|
||||
|
||||
Ciphertext(const FHE_PK &pk);
|
||||
|
||||
Ciphertext(const Rq_Element& a0, const Rq_Element& a1, const Ciphertext& C) :
|
||||
Ciphertext(C.get_params())
|
||||
{
|
||||
set(a0, a1, C.get_pk_id());
|
||||
}
|
||||
|
||||
~Ciphertext() { ; }
|
||||
|
||||
// Rely on default copy assignment/constructor
|
||||
|
||||
word get_pk_id() { return pk_id; }
|
||||
word get_pk_id() const { return pk_id; }
|
||||
|
||||
void set(const Rq_Element& a0, const Rq_Element& a1, word pk_id)
|
||||
{ cc0=a0; cc1=a1; this->pk_id = pk_id; }
|
||||
@@ -93,9 +99,14 @@ class Ciphertext
|
||||
template <class FD>
|
||||
Ciphertext& operator*=(const Plaintext_<FD>& other) { ::mul(*this, *this, other); return *this; }
|
||||
|
||||
Ciphertext mul(const Ciphertext& x, const FHE_PK& pk) const
|
||||
Ciphertext mul(const FHE_PK& pk, const Ciphertext& x) const
|
||||
{ Ciphertext res(*params); ::mul(res, *this, x, pk); return res; }
|
||||
|
||||
Ciphertext mul_by_X_i(int i, const FHE_PK&) const
|
||||
{
|
||||
return {cc0.mul_by_X_i(i), cc1.mul_by_X_i(i), *this};
|
||||
}
|
||||
|
||||
int level() const { return cc0.level(); }
|
||||
|
||||
// pack/unpack (like IO) also assume params are known and already set
|
||||
|
||||
@@ -4,34 +4,30 @@
|
||||
|
||||
void DiscreteGauss::set(double RR)
|
||||
{
|
||||
R=RR;
|
||||
e=exp(1); e1=exp(0.25); e2=exp(-1.35);
|
||||
if (RR > 0 or NewHopeB < 1)
|
||||
NewHopeB = max(1, int(round(2 * RR * RR)));
|
||||
assert(NewHopeB > 0);
|
||||
}
|
||||
|
||||
|
||||
|
||||
/* Return a value distributed normaly with std dev R */
|
||||
int DiscreteGauss::sample(PRNG& G, int stretch) const
|
||||
/* This uses the approximation to a Gaussian via
|
||||
* binomial distribution
|
||||
*
|
||||
* This procedure consumes 2*NewHopeB bits
|
||||
*
|
||||
*/
|
||||
int DiscreteGauss::sample(PRNG &G, int stretch) const
|
||||
{
|
||||
/* Uses the ratio method from Wikipedia to get a
|
||||
Normal(0,1) variable X
|
||||
Then multiplies X by R
|
||||
*/
|
||||
double U,V,X,R1,R2,R3,X2;
|
||||
int ans;
|
||||
while (true)
|
||||
{ U=G.get_double();
|
||||
V=G.get_double();
|
||||
R1=5-4*e1*U;
|
||||
R2=4*e2/U+1.4;
|
||||
R3=-4/log(U);
|
||||
X=sqrt(8/e)*(V-0.5)/U;
|
||||
X2=X*X;
|
||||
if (X2<=R1 || (X2<R2 && X2<=R3))
|
||||
{ ans=(int) (X*R*stretch);
|
||||
return ans;
|
||||
}
|
||||
int s= 0;
|
||||
// stretch refers to the standard deviation
|
||||
int B = NewHopeB * stretch * stretch;
|
||||
for (int i = 0; i < B; i++)
|
||||
{
|
||||
s += G.get_bit();
|
||||
s -= G.get_bit();
|
||||
}
|
||||
return s;
|
||||
}
|
||||
|
||||
|
||||
@@ -39,7 +35,8 @@ int DiscreteGauss::sample(PRNG& G, int stretch) const
|
||||
void RandomVectors::set(int nn,int hh,double R)
|
||||
{
|
||||
n=nn;
|
||||
h=hh;
|
||||
if (h > 0)
|
||||
h=hh;
|
||||
DG.set(R);
|
||||
}
|
||||
|
||||
@@ -124,7 +121,7 @@ bool RandomVectors::operator!=(const RandomVectors& other) const
|
||||
|
||||
bool DiscreteGauss::operator!=(const DiscreteGauss& other) const
|
||||
{
|
||||
if (other.R != R or other.e != e or other.e1 != e1 or other.e2 != e2)
|
||||
if (other.NewHopeB != NewHopeB)
|
||||
return true;
|
||||
else
|
||||
return false;
|
||||
|
||||
@@ -7,35 +7,33 @@
|
||||
|
||||
#include <FHE/Generator.h>
|
||||
#include "Math/modp.h"
|
||||
#include "Math/gfp.h"
|
||||
#include "Tools/random.h"
|
||||
#include <vector>
|
||||
|
||||
/* Uses the Ratio method as opposed to the Box-Muller method
|
||||
* as the Ratio method is thread safe, but it is 50% slower
|
||||
*/
|
||||
#include <math.h>
|
||||
|
||||
class DiscreteGauss
|
||||
{
|
||||
double R; // Standard deviation
|
||||
double e; // Precomputed exp(1)
|
||||
double e1; // Precomputed exp(0.25)
|
||||
double e2; // Precomputed exp(-1.35)
|
||||
/* This is the bound we use on for the NewHope approximation
|
||||
* to a discrete Gaussian with sigma=sqrt(B/2)
|
||||
*/
|
||||
int NewHopeB;
|
||||
|
||||
public:
|
||||
|
||||
void set(double R);
|
||||
|
||||
void pack(octetStream& o) const { o.serialize(R); }
|
||||
void unpack(octetStream& o) { o.unserialize(R); }
|
||||
void pack(octetStream& o) const { o.serialize(NewHopeB); }
|
||||
void unpack(octetStream& o) { o.unserialize(NewHopeB); }
|
||||
|
||||
DiscreteGauss() { set(0); }
|
||||
DiscreteGauss(double R) { set(R); }
|
||||
~DiscreteGauss() { ; }
|
||||
|
||||
// Rely on default copy constructor/assignment
|
||||
|
||||
int sample(PRNG& G, int stretch = 1) const;
|
||||
double get_R() const { return R; }
|
||||
double get_R() const { return sqrt(0.5 * NewHopeB); }
|
||||
int get_NewHopeB() const { return NewHopeB; }
|
||||
|
||||
bool operator!=(const DiscreteGauss& other) const;
|
||||
};
|
||||
@@ -56,8 +54,8 @@ class RandomVectors
|
||||
void pack(octetStream& o) const { o.store(n); o.store(h); DG.pack(o); }
|
||||
void unpack(octetStream& o) { o.get(n); o.get(h); DG.unpack(o); }
|
||||
|
||||
RandomVectors() { ; }
|
||||
RandomVectors(int nn,int hh,double R) { set(nn,hh,R); }
|
||||
RandomVectors(int h, double R) : n(0), h(h), DG(R) {}
|
||||
RandomVectors(int nn,int hh,double R) : DG(R) { set(nn,hh,R); }
|
||||
~RandomVectors() { ; }
|
||||
|
||||
// Rely on default copy constructor/assignment
|
||||
@@ -81,7 +79,8 @@ class RandomVectors
|
||||
bool operator!=(const RandomVectors& other) const;
|
||||
};
|
||||
|
||||
class RandomGenerator : public Generator<bigint>
|
||||
template<class T>
|
||||
class RandomGenerator : public Generator<T>
|
||||
{
|
||||
protected:
|
||||
mutable PRNG G;
|
||||
@@ -90,39 +89,42 @@ public:
|
||||
RandomGenerator(PRNG& G) { this->G.SetSeed(G); }
|
||||
};
|
||||
|
||||
class UniformGenerator : public RandomGenerator
|
||||
template<class T>
|
||||
class UniformGenerator : public RandomGenerator<T>
|
||||
{
|
||||
int n_bits;
|
||||
bool positive;
|
||||
|
||||
public:
|
||||
UniformGenerator(PRNG& G, int n_bits, bool positive = true) :
|
||||
RandomGenerator(G), n_bits(n_bits), positive(positive) {}
|
||||
Generator* clone() const { return new UniformGenerator(*this); }
|
||||
void get(bigint& x) const { G.get_bigint(x, n_bits, positive); }
|
||||
RandomGenerator<T>(G), n_bits(n_bits), positive(positive) {}
|
||||
Generator<T>* clone() const { return new UniformGenerator<T>(*this); }
|
||||
void get(T& x) const { this->G.get(x, n_bits, positive); }
|
||||
};
|
||||
|
||||
class GaussianGenerator : public RandomGenerator
|
||||
template<class T>
|
||||
class GaussianGenerator : public RandomGenerator<T>
|
||||
{
|
||||
DiscreteGauss DG;
|
||||
int stretch;
|
||||
|
||||
public:
|
||||
GaussianGenerator(const DiscreteGauss& DG, PRNG& G, int stretch = 1) :
|
||||
RandomGenerator(G), DG(DG), stretch(stretch) {}
|
||||
Generator* clone() const { return new GaussianGenerator(*this); }
|
||||
void get(bigint& x) const { mpz_set_si(x.get_mpz_t(), DG.sample(G, stretch)); }
|
||||
RandomGenerator<T>(G), DG(DG), stretch(stretch) {}
|
||||
Generator<T>* clone() const { return new GaussianGenerator<T>(*this); }
|
||||
void get(T& x) const { x = DG.sample(this->G, stretch); }
|
||||
};
|
||||
|
||||
int sample_half(PRNG& G);
|
||||
|
||||
class HalfGenerator : public RandomGenerator
|
||||
template<class T>
|
||||
class HalfGenerator : public RandomGenerator<T>
|
||||
{
|
||||
public:
|
||||
HalfGenerator(PRNG& G) :
|
||||
RandomGenerator(G) {}
|
||||
Generator* clone() const { return new HalfGenerator(*this); }
|
||||
void get(bigint& x) const { mpz_set_si(x.get_mpz_t(), sample_half(G)); }
|
||||
RandomGenerator<T>(G) {}
|
||||
Generator<T>* clone() const { return new HalfGenerator<T>(*this); }
|
||||
void get(T& x) const { x = sample_half(this->G); }
|
||||
};
|
||||
|
||||
#endif
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
#include "FHE/FFT_Data.h"
|
||||
#include "FHE/FFT.h"
|
||||
|
||||
#include "Math/Subroutines.h"
|
||||
#include "FHE/Subroutines.h"
|
||||
|
||||
|
||||
void FFT_Data::assign(const FFT_Data& FFTD)
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#include "Math/modp.h"
|
||||
#include "Math/Zp_Data.h"
|
||||
#include "Math/gfp.h"
|
||||
#include "Math/fixint.h"
|
||||
#include "FHE/Ring.h"
|
||||
|
||||
/* Class for holding modular arithmetic data wrt the ring
|
||||
@@ -36,6 +37,7 @@ class FFT_Data
|
||||
public:
|
||||
typedef gfp T;
|
||||
typedef bigint S;
|
||||
typedef fixint<GFP_MOD_SZ> poly_type;
|
||||
|
||||
void init(const Ring& Rg,const Zp_Data& PrD);
|
||||
|
||||
|
||||
@@ -161,7 +161,7 @@ void FHE_PK::encrypt(Ciphertext& c, const vector<S>& mess,
|
||||
const Random_Coins& rc) const
|
||||
{
|
||||
Rq_Element mm((*params).FFTD(),polynomial,polynomial);
|
||||
mm.from_vec(mess);
|
||||
mm.from(Iterator<S>(mess));
|
||||
quasi_encrypt(c, mm, rc);
|
||||
}
|
||||
|
||||
@@ -302,6 +302,7 @@ void FHE_SK::dist_decrypt_1(vector<bigint>& vv,const Ciphertext& ctx,int player_
|
||||
{ dec_sh.negate(); }
|
||||
|
||||
// Now convert to a vector of bigint's and add the required randomness
|
||||
assert(pr != 0);
|
||||
bigint Bd=((*params).B()<<(*params).secp())/(num_players*pr);
|
||||
Bd=Bd/2; // make slightly smaller due to rounding issues
|
||||
|
||||
@@ -334,6 +335,29 @@ void FHE_SK::dist_decrypt_2(vector<bigint>& vv,const vector<bigint>& vv1) const
|
||||
}
|
||||
}
|
||||
|
||||
void FHE_PK::pack(octetStream& o) const
|
||||
{
|
||||
o.append((octet*) "PKPKPKPK", 8);
|
||||
a0.pack(o);
|
||||
b0.pack(o);
|
||||
Sw_a.pack(o);
|
||||
Sw_b.pack(o);
|
||||
pr.pack(o);
|
||||
}
|
||||
|
||||
void FHE_PK::unpack(octetStream& o)
|
||||
{
|
||||
char tag[8];
|
||||
o.consume((octet*) tag, 8);
|
||||
if (memcmp(tag, "PKPKPKPK", 8))
|
||||
throw runtime_error("invalid serialization of public key");
|
||||
a0.unpack(o);
|
||||
b0.unpack(o);
|
||||
Sw_a.unpack(o);
|
||||
Sw_b.unpack(o);
|
||||
pr.unpack(o);
|
||||
}
|
||||
|
||||
|
||||
bool FHE_PK::operator!=(const FHE_PK& x) const
|
||||
{
|
||||
@@ -379,7 +403,7 @@ template Ciphertext FHE_PK::encrypt(const Plaintext_<P2Data>& mess) const;
|
||||
|
||||
template void FHE_PK::encrypt(Ciphertext& c, const vector<int>& mess,
|
||||
const Random_Coins& rc) const;
|
||||
template void FHE_PK::encrypt(Ciphertext& c, const vector<bigint>& mess,
|
||||
template void FHE_PK::encrypt(Ciphertext& c, const vector<fixint<2>>& mess,
|
||||
const Random_Coins& rc) const;
|
||||
|
||||
template Plaintext_<FFT_Data> FHE_SK::decrypt(const Ciphertext& c,
|
||||
|
||||
@@ -27,7 +27,7 @@ class FHE_SK
|
||||
// secret key always on lower level
|
||||
void assign(const Rq_Element& s) { sk=s; sk.lower_level(); }
|
||||
|
||||
FHE_SK(const FHE_Params& pms, const bigint& p = 0)
|
||||
FHE_SK(const FHE_Params& pms, const bigint& p)
|
||||
: sk(pms.FFTD(),evaluation,evaluation) { params=&pms; pr=p; }
|
||||
|
||||
FHE_SK(const FHE_PK& pk);
|
||||
@@ -110,6 +110,17 @@ class FHE_PK
|
||||
Sw_b(pms.FFTD(),evaluation,evaluation)
|
||||
{ params=&pms; pr=p; }
|
||||
|
||||
FHE_PK(const FHE_Params& pms, int p) :
|
||||
FHE_PK(pms, bigint(p))
|
||||
{
|
||||
}
|
||||
|
||||
template<class FD>
|
||||
FHE_PK(const FHE_Params& pms, const FD& FTD) :
|
||||
FHE_PK(pms, FTD.get_prime())
|
||||
{
|
||||
}
|
||||
|
||||
// Rely on default copy constructor/assignment
|
||||
|
||||
const Rq_Element& a() const { return a0; }
|
||||
@@ -148,10 +159,8 @@ class FHE_PK
|
||||
friend istream& operator>>(istream& s, FHE_PK& PK)
|
||||
{ s >> PK.a0 >> PK.b0 >> PK.Sw_a >> PK.Sw_b; return s; }
|
||||
|
||||
void pack(octetStream& o) const
|
||||
{ a0.pack(o); b0.pack(o); Sw_a.pack(o); Sw_b.pack(o); pr.pack(o); }
|
||||
void unpack(octetStream& o)
|
||||
{ a0.unpack(o); b0.unpack(o); Sw_a.unpack(o); Sw_b.unpack(o); pr.unpack(o); }
|
||||
void pack(octetStream& o) const;
|
||||
void unpack(octetStream& o);
|
||||
|
||||
bool operator!=(const FHE_PK& x) const;
|
||||
|
||||
|
||||
@@ -29,14 +29,14 @@ class FHE_Params
|
||||
|
||||
public:
|
||||
|
||||
FHE_Params(int n_mults = 1) : FFTData(n_mults + 1), sec_p(-1) {}
|
||||
FHE_Params(int n_mults = 1) : FFTData(n_mults + 1), Chi(64, 0.7), sec_p(-1) {}
|
||||
|
||||
int n_mults() const { return FFTData.size() - 1; }
|
||||
|
||||
// Rely on default copy assignment/constructor (not that they should
|
||||
// ever be needed)
|
||||
|
||||
void set(const Ring& R,const vector<bigint>& primes,double r=3.2,int hwt=64);
|
||||
void set(const Ring& R,const vector<bigint>& primes,double r=-1,int hwt=-1);
|
||||
void set_sec(int sec);
|
||||
|
||||
vector<bigint> sampleGaussian(PRNG& G, int noise_boost = 1) const
|
||||
|
||||
@@ -24,40 +24,17 @@ NTL_CLIENT
|
||||
|
||||
#include "FHEOffline/DataSetup.h"
|
||||
|
||||
|
||||
template <>
|
||||
void generate_setup(int n_parties, int plaintext_length, int sec,
|
||||
FHE_Params& params, FFT_Data& FTD, int slack, bool round_up)
|
||||
{
|
||||
Ring Rp;
|
||||
bigint p0p,p1p,p;
|
||||
SPDZ_Data_Setup_Char_p(Rp, FTD, p0p, p1p, n_parties, plaintext_length, sec,
|
||||
slack, round_up);
|
||||
params.set(Rp, {p0p, p1p});
|
||||
}
|
||||
|
||||
|
||||
template <>
|
||||
void generate_setup(int n_parties, int plaintext_length, int sec,
|
||||
FHE_Params& params, P2Data& P2D, int slack, bool round_up)
|
||||
{
|
||||
Ring R;
|
||||
bigint pr0,pr1;
|
||||
SPDZ_Data_Setup_Char_2(R, P2D, pr0, pr1, n_parties, plaintext_length, sec,
|
||||
slack, round_up);
|
||||
params.set(R, {pr0, pr1});
|
||||
}
|
||||
|
||||
|
||||
void generate_setup(int n, int lgp, int lg2, int sec, bool skip_2,
|
||||
int slack, bool round_up)
|
||||
{
|
||||
DataSetup setup;
|
||||
|
||||
// do the full setup for SHE data
|
||||
generate_setup(n, lgp, sec, setup.setup_p.params, setup.setup_p.FieldD, slack, round_up);
|
||||
Parameters(n, lgp, sec, slack, round_up).generate_setup(setup.setup_p.params,
|
||||
setup.setup_p.FieldD);
|
||||
if (!skip_2)
|
||||
generate_setup(n, lg2, sec, setup.setup_2.params, setup.setup_2.FieldD, slack, round_up);
|
||||
Parameters(n, lg2, sec, slack, round_up).generate_setup(
|
||||
setup.setup_2.params, setup.setup_2.FieldD);
|
||||
|
||||
setup.write_setup(skip_2);
|
||||
}
|
||||
@@ -208,9 +185,9 @@ int finalize_lengths(int& lg2p0, int& lg2p1, int n, int m, int* lg2pi, bool roun
|
||||
/*
|
||||
* Subroutine for creating the FHE parameters
|
||||
*/
|
||||
int SPDZ_Data_Setup_Char_p_Sub(Ring& R, bigint& pr0, bigint& pr1, int n,
|
||||
int idx, int& m, bigint& p, int sec, int slack = 0, bool round_up = false)
|
||||
int Parameters::SPDZ_Data_Setup_Char_p_Sub(int idx, int& m, bigint& p)
|
||||
{
|
||||
int n = n_parties;
|
||||
int lg2pi[5][2][9]
|
||||
= { { {130,132,132,132,132,132,132,132,132},
|
||||
{104,104,104,106,106,108,108,110,110} },
|
||||
@@ -291,13 +268,13 @@ void generate_modulus(bigint& pr, const int m, const bigint p, const int lg2pr,
|
||||
/*
|
||||
* Create the char p FHE parameters
|
||||
*/
|
||||
void SPDZ_Data_Setup_Char_p(Ring& R, FFT_Data& FTD, bigint& pr0, bigint& pr1,
|
||||
int n, int lgp, int sec, int slack, bool round_up)
|
||||
template <>
|
||||
void Parameters::SPDZ_Data_Setup(FFT_Data& FTD)
|
||||
{
|
||||
bigint p;
|
||||
int idx, m;
|
||||
SPDZ_Data_Setup_Primes(p, lgp, idx, m);
|
||||
SPDZ_Data_Setup_Char_p_Sub(R, pr0, pr1, n, idx, m, p, sec, slack, round_up);
|
||||
SPDZ_Data_Setup_Primes(p, plaintext_length, idx, m);
|
||||
SPDZ_Data_Setup_Char_p_Sub(idx, m, p);
|
||||
|
||||
Zp_Data Zp(p);
|
||||
gfp::init_field(p);
|
||||
@@ -574,9 +551,12 @@ void char_2_dimension(int& m, int& lg2)
|
||||
}
|
||||
}
|
||||
|
||||
void SPDZ_Data_Setup_Char_2(Ring& R, P2Data& P2D, bigint& pr0, bigint& pr1,
|
||||
int n, int lg2, int sec, int slack, bool round_up)
|
||||
template <>
|
||||
void Parameters::SPDZ_Data_Setup(P2Data& P2D)
|
||||
{
|
||||
int n = n_parties;
|
||||
int lg2 = plaintext_length;
|
||||
|
||||
int lg2pi[2][9]
|
||||
= { {70,70,70,70,70,70,70,70,70},
|
||||
{70,75,75,75,75,80,80,80,80}
|
||||
@@ -760,7 +740,8 @@ void SPDZ_Data_Setup_Char_p_General(Ring& R, PPData& PPD, bigint& pr0,
|
||||
{ throw bad_value(); }
|
||||
cout << "Chosen value of m=" << m << "\t\t phi(m)=" << bphi_m << " : " << min_hwt << " : " << bmx << endl;
|
||||
|
||||
SPDZ_Data_Setup_Char_p_Sub(R,pr0,pr1,n,idx,m,p,sec);
|
||||
Parameters parameters(n, lgp, sec);
|
||||
parameters.SPDZ_Data_Setup_Char_p_Sub(idx,m,p);
|
||||
int mx=0;
|
||||
for (int i=0; i<R.phi_m(); i++)
|
||||
{ if (mx<R.Phi()[i]) { mx=R.Phi()[i]; } }
|
||||
@@ -769,6 +750,9 @@ void SPDZ_Data_Setup_Char_p_General(Ring& R, PPData& PPD, bigint& pr0,
|
||||
Zp_Data Zp(p);
|
||||
gfp::init_field(p);
|
||||
PPD.init(R,Zp);
|
||||
|
||||
pr0 = parameters.pr0;
|
||||
pr1 = parameters.pr1;
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -9,19 +9,47 @@
|
||||
#include "FHE/PPData.h"
|
||||
#include "FHE/FHE_Params.h"
|
||||
|
||||
|
||||
/* Routines to set up key sizes given the number of players n
|
||||
* - And size lgp of plaintext modulus p for the char p case
|
||||
*/
|
||||
|
||||
class Parameters
|
||||
{
|
||||
int n_parties;
|
||||
int plaintext_length;
|
||||
int sec;
|
||||
int slack;
|
||||
bool round_up;
|
||||
|
||||
public:
|
||||
Ring R;
|
||||
bigint pr0, pr1;
|
||||
|
||||
Parameters(int n_parties, int plaintext_length, int sec, int slack = 0,
|
||||
bool round_up = false) :
|
||||
n_parties(n_parties), plaintext_length(plaintext_length), sec(sec), slack(
|
||||
slack), round_up(round_up)
|
||||
{
|
||||
}
|
||||
|
||||
template<class FD>
|
||||
void generate_setup(FHE_Params& params, FD& FTD)
|
||||
{
|
||||
SPDZ_Data_Setup(FTD);
|
||||
params.set(R, {pr0, pr1});
|
||||
}
|
||||
|
||||
template<class FD>
|
||||
void SPDZ_Data_Setup(FD& FTD);
|
||||
|
||||
int SPDZ_Data_Setup_Char_p_Sub(int idx, int& m, bigint& p);
|
||||
|
||||
};
|
||||
|
||||
// Main setup routine (need NTL if online_only is false)
|
||||
void generate_setup(int nparties, int lgp, int lg2,
|
||||
int sec, bool skip_2 = false, int slack = 0, bool round_up = false);
|
||||
|
||||
template <class FD>
|
||||
void generate_setup(int n_parties, int plaintext_length, int sec,
|
||||
FHE_Params& params, FD& FTD, int slack, bool round_up);
|
||||
|
||||
// semi-homomorphic, includes slack
|
||||
template <class FD>
|
||||
int generate_semi_setup(int plaintext_length, int sec,
|
||||
@@ -35,10 +63,6 @@ int common_semi_setup(FHE_Params& params, int m, bigint p, int lgp0, int lgp1,
|
||||
void init(Ring& Rg,int m);
|
||||
void init(P2Data& P2D,const Ring& Rg);
|
||||
|
||||
// For use when we only care about p being of a certain size
|
||||
void SPDZ_Data_Setup_Char_p(Ring& R, FFT_Data& FTD, bigint& pr0, bigint& pr1,
|
||||
int n, int lgp, int sec, int slack = 0, bool round_up = false);
|
||||
|
||||
// For use when we want p to be a specific value
|
||||
void SPDZ_Data_Setup_Char_p_General(Ring& R, PPData& PPD, bigint& pr0,
|
||||
bigint& pr1, int n, int sec, bigint& p);
|
||||
@@ -53,9 +77,6 @@ void generate_modulus(bigint& pr, const int m, const bigint p, const int lg2pr,
|
||||
// pre-generated dimensions for characteristic 2
|
||||
void char_2_dimension(int& m, int& lg2);
|
||||
|
||||
void SPDZ_Data_Setup_Char_2(Ring& R, P2Data& P2D, bigint& pr0, bigint& pr1,
|
||||
int n, int lg2, int sec = -1, int slacke = 0, bool round_up = false);
|
||||
|
||||
// try to avoid expensive generation by loading from disk if possible
|
||||
void load_or_generate(P2Data& P2D, const Ring& Rg);
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
|
||||
#include <FHE/NoiseBounds.h>
|
||||
#include "FHEOffline/Proof.h"
|
||||
#include "Protocols/CowGearOptions.h"
|
||||
#include <math.h>
|
||||
|
||||
|
||||
@@ -13,14 +14,44 @@ SemiHomomorphicNoiseBounds::SemiHomomorphicNoiseBounds(const bigint& p,
|
||||
p(p), phi_m(phi_m), n(n), sec(sec),
|
||||
slack(numBits(Proof::slack(slack_param, sec, phi_m))), sigma(sigma), h(h)
|
||||
{
|
||||
if (sigma <= 0)
|
||||
this->sigma = sigma = FHE_Params().get_R();
|
||||
cerr << "Standard deviation: " << this->sigma << endl;
|
||||
h += extra_h * sec;
|
||||
B_clean = (phi_m * p / 2
|
||||
+ p * sigma
|
||||
* (16 * phi_m * sqrt(n / 2) + 6 * sqrt(phi_m)
|
||||
+ 16 * sqrt(n * h * phi_m))) << slack;
|
||||
B_scale = p * sqrt(3 * phi_m) * (1 + 8 * sqrt(n * h) / 3);
|
||||
drown = 1 + (bigint(1) << sec);
|
||||
cout << "log(slack): " << slack << endl;
|
||||
produce_epsilon_constants();
|
||||
|
||||
if (CowGearOptions::singleton.top_gear())
|
||||
{
|
||||
// according to documentation of SCALE-MAMBA 1.7
|
||||
// excluding a factor of n because we don't always add up n ciphertexts
|
||||
B_clean = (bigint(phi_m) << (sec + 2)) * p
|
||||
* (20.5 + c1 * sigma * sqrt(phi_m) + 20 * c1 * sqrt(h));
|
||||
mpf_class V_s;
|
||||
if (h > 0)
|
||||
V_s = sqrt(h);
|
||||
else
|
||||
V_s = sigma * sqrt(phi_m);
|
||||
B_scale = (c1 + c2 * V_s) * p * sqrt(phi_m / 12.0);
|
||||
#ifdef NOISY
|
||||
cout << "p * sqrt(phi(m) / 12): " << p * sqrt(phi_m / 12.0) << endl;
|
||||
cout << "V_s: " << V_s << endl;
|
||||
cout << "c1: " << c1 << endl;
|
||||
cout << "c2: " << c2 << endl;
|
||||
cout << "c1 + c2 * V_s: " << c1 + c2 * V_s << endl;
|
||||
cout << "B_scale: " << B_scale << endl;
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
B_clean = (phi_m * p / 2
|
||||
+ p * sigma
|
||||
* (16 * phi_m * sqrt(n / 2) + 6 * sqrt(phi_m)
|
||||
+ 16 * sqrt(n * h * phi_m))) << slack;
|
||||
B_scale = p * sqrt(3 * phi_m) * (1 + 8 * sqrt(n * h) / 3);
|
||||
cout << "log(slack): " << slack << endl;
|
||||
}
|
||||
|
||||
drown = 1 + n * (bigint(1) << sec);
|
||||
}
|
||||
|
||||
bigint SemiHomomorphicNoiseBounds::min_p0(const bigint& p1)
|
||||
@@ -34,28 +65,62 @@ bigint SemiHomomorphicNoiseBounds::min_p0()
|
||||
return B_clean * drown * p;
|
||||
}
|
||||
|
||||
double SemiHomomorphicNoiseBounds::min_phi_m(int log_q)
|
||||
double SemiHomomorphicNoiseBounds::min_phi_m(int log_q, double sigma)
|
||||
{
|
||||
if (sigma <= 0)
|
||||
sigma = FHE_Params().get_R();
|
||||
// the constant was updated using Martin Albrecht's LWE estimator in Sep 2019
|
||||
return 37.8 * (log_q - log2(3.2));
|
||||
return 37.8 * (log_q - log2(sigma));
|
||||
}
|
||||
|
||||
void SemiHomomorphicNoiseBounds::produce_epsilon_constants()
|
||||
{
|
||||
double C[3];
|
||||
|
||||
for (int i = 0; i < 3; i++)
|
||||
{
|
||||
C[i] = -1;
|
||||
}
|
||||
for (double x = 0.1; x < 10.0; x += .1)
|
||||
{
|
||||
double t = erfc(x), tp = 1;
|
||||
for (int i = 1; i < 3; i++)
|
||||
{
|
||||
tp *= t;
|
||||
double lgtp = log(tp) / log(2.0);
|
||||
if (C[i] < 0 && lgtp < FHE_epsilon)
|
||||
{
|
||||
C[i] = pow(x, i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c1 = C[1];
|
||||
c2 = C[2];
|
||||
}
|
||||
|
||||
NoiseBounds::NoiseBounds(const bigint& p, int phi_m, int n, int sec, int slack,
|
||||
double sigma, int h) :
|
||||
SemiHomomorphicNoiseBounds(p, phi_m, n, sec, slack, false, sigma, h)
|
||||
{
|
||||
B_KS = p * phi_m * sigma
|
||||
* (pow(n, 2.5) * (1.49 * sqrt(h * phi_m) + 2.11 * h)
|
||||
+ 2.77 * n * n * sqrt(h)
|
||||
+ pow(n, 1.5) * (1.96 * sqrt(phi_m) * 2.77 * sqrt(h))
|
||||
+ 4.62 * n);
|
||||
if (CowGearOptions::singleton.top_gear())
|
||||
{
|
||||
B_KS = p * c2 * this->sigma * phi_m / sqrt(12);
|
||||
}
|
||||
else
|
||||
{
|
||||
B_KS = p * phi_m * mpf_class(this->sigma)
|
||||
* (pow(n, 2.5) * (1.49 * sqrt(h * phi_m) + 2.11 * h)
|
||||
+ 2.77 * n * n * sqrt(h)
|
||||
+ pow(n, 1.5) * (1.96 * sqrt(phi_m) * 2.77 * sqrt(h))
|
||||
+ 4.62 * n);
|
||||
}
|
||||
#ifdef NOISY
|
||||
cout << "p size: " << numBits(p) << endl;
|
||||
cout << "phi(m): " << phi_m << endl;
|
||||
cout << "n: " << n << endl;
|
||||
cout << "sec: " << sec << endl;
|
||||
cout << "sigma: " << sigma << endl;
|
||||
cout << "sigma: " << this->sigma << endl;
|
||||
cout << "h: " << h << endl;
|
||||
cout << "B_clean size: " << numBits(B_clean) << endl;
|
||||
cout << "B_scale size: " << numBits(B_scale) << endl;
|
||||
|
||||
@@ -13,27 +13,33 @@ int phi_N(int N);
|
||||
class SemiHomomorphicNoiseBounds
|
||||
{
|
||||
protected:
|
||||
static const int FHE_epsilon = 55;
|
||||
|
||||
const bigint p;
|
||||
const int phi_m;
|
||||
const int n;
|
||||
const int sec;
|
||||
int slack;
|
||||
const double sigma;
|
||||
mpf_class sigma;
|
||||
const int h;
|
||||
|
||||
bigint B_clean;
|
||||
bigint B_scale;
|
||||
bigint drown;
|
||||
|
||||
mpf_class c1, c2;
|
||||
|
||||
void produce_epsilon_constants();
|
||||
|
||||
public:
|
||||
SemiHomomorphicNoiseBounds(const bigint& p, int phi_m, int n, int sec,
|
||||
int slack, bool extra_h = false, double sigma = 3.2, int h = 64);
|
||||
int slack, bool extra_h = false, double sigma = -1, int h = 64);
|
||||
// with scaling
|
||||
bigint min_p0(const bigint& p1);
|
||||
// without scaling
|
||||
bigint min_p0();
|
||||
bigint min_p0(bool scale, const bigint& p1) { return scale ? min_p0(p1) : min_p0(); }
|
||||
static double min_phi_m(int log_q);
|
||||
static double min_phi_m(int log_q, double sigma = -1);
|
||||
};
|
||||
|
||||
// as per ePrint 2012:642 for slack = 0
|
||||
@@ -43,7 +49,7 @@ class NoiseBounds : public SemiHomomorphicNoiseBounds
|
||||
|
||||
public:
|
||||
NoiseBounds(const bigint& p, int phi_m, int n, int sec, int slack,
|
||||
double sigma = 3.2, int h = 64);
|
||||
double sigma = -1, int h = 64);
|
||||
bigint U1(const bigint& p0, const bigint& p1);
|
||||
bigint U2(const bigint& p0, const bigint& p1);
|
||||
bigint min_p0(const bigint& p0, const bigint& p1);
|
||||
|
||||
@@ -20,6 +20,7 @@ class P2Data
|
||||
public:
|
||||
typedef gf2n_short T;
|
||||
typedef int S;
|
||||
typedef int poly_type;
|
||||
|
||||
int num_slots() const { return slots; }
|
||||
int degree() const { return A.size() ? A.size() : 0; }
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#include "Math/Subroutines.h"
|
||||
#include "FHE/Subroutines.h"
|
||||
#include "FHE/PPData.h"
|
||||
#include "FHE/FFT.h"
|
||||
#include "FHE/Matrix.h"
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#include "Math/modp.h"
|
||||
#include "Math/Zp_Data.h"
|
||||
#include "Math/gfp.h"
|
||||
#include "Math/fixint.h"
|
||||
#include "FHE/Ring.h"
|
||||
|
||||
/* Class for holding modular arithmetic data wrt the ring
|
||||
@@ -16,6 +17,7 @@ class PPData
|
||||
public:
|
||||
typedef gf2n_short T;
|
||||
typedef bigint S;
|
||||
typedef fixint<GFP_MOD_SZ> poly_type;
|
||||
|
||||
Ring R;
|
||||
Zp_Data prData;
|
||||
|
||||
@@ -3,6 +3,21 @@
|
||||
#include "FHE/Ring_Element.h"
|
||||
#include "FHE/PPData.h"
|
||||
#include "FHE/P2Data.h"
|
||||
#include "FHE/Rq_Element.h"
|
||||
#include "FHE_Keys.h"
|
||||
#include "Math/Z2k.hpp"
|
||||
|
||||
|
||||
|
||||
template<>
|
||||
void Plaintext<gfp, FFT_Data, bigint>::from(const Generator<bigint>& source) const
|
||||
{
|
||||
for (auto& x : b)
|
||||
{
|
||||
source.get(bigint::tmp);
|
||||
x = bigint::tmp;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template<>
|
||||
@@ -11,7 +26,7 @@ void Plaintext<gfp,FFT_Data,bigint>::from_poly() const
|
||||
if (type!=Polynomial) { return; }
|
||||
|
||||
Ring_Element e(*Field_Data,polynomial);
|
||||
e.from_vec(b);
|
||||
e.from(b);
|
||||
e.change_rep(evaluation);
|
||||
a.resize(n_slots);
|
||||
for (unsigned int i=0; i<a.size(); i++)
|
||||
@@ -29,7 +44,7 @@ void Plaintext<gfp,FFT_Data,bigint>::to_poly() const
|
||||
for (unsigned int i=0; i<a.size(); i++)
|
||||
{ e.set_element(i,a[i].get()); }
|
||||
e.change_rep(polynomial);
|
||||
e.to_vec_bigint(b);
|
||||
from(e.get_iterator());
|
||||
type=Both;
|
||||
}
|
||||
|
||||
@@ -59,7 +74,10 @@ void Plaintext<gfp,PPData,bigint>::to_poly() const
|
||||
{ bb[i]=a[i].get(); }
|
||||
(*Field_Data).from_eval(bb);
|
||||
for (unsigned int i=0; i<bb.size(); i++)
|
||||
{ to_bigint(b[i],bb[i],(*Field_Data).prData); }
|
||||
{
|
||||
to_bigint(bigint::tmp,bb[i],(*Field_Data).prData);
|
||||
b[i] = bigint::tmp;
|
||||
}
|
||||
type=Both;
|
||||
}
|
||||
|
||||
@@ -122,7 +140,7 @@ void Plaintext<T, FD, S>::allocate_slots(const bigint& value)
|
||||
{
|
||||
b.resize(degree);
|
||||
for (auto& x : b)
|
||||
x = value;
|
||||
x.allocate_slots(value);
|
||||
}
|
||||
|
||||
template<>
|
||||
@@ -151,52 +169,20 @@ void signed_mod(bigint& x, const bigint& mod, const bigint& half_mod, const bigi
|
||||
x += dest_mod;
|
||||
}
|
||||
|
||||
template<>
|
||||
void Plaintext<gfp,FFT_Data,bigint>::set_poly_mod(const Generator<bigint>& generator,const bigint& mod)
|
||||
template<class T, class FD, class S>
|
||||
void Plaintext<T, FD, S>::set_poly_mod(const Generator<bigint>& generator,const bigint& mod)
|
||||
{
|
||||
allocate(Polynomial);
|
||||
bigint half_mod = mod / 2;
|
||||
for (unsigned int i=0; i<b.size(); i++)
|
||||
{
|
||||
generator.get(b[i]);
|
||||
signed_mod(b[i], mod, half_mod, Field_Data->get_prime());
|
||||
generator.get(bigint::tmp);
|
||||
signed_mod(bigint::tmp, mod, half_mod, Field_Data->get_prime());
|
||||
b[i] = bigint::tmp;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template<>
|
||||
void Plaintext<gfp,FFT_Data,bigint>::set_poly_mod(const vector<bigint>& vv,const bigint& mod)
|
||||
{
|
||||
b = vv;
|
||||
vector<bigint>& pol = b;
|
||||
bigint half_mod = mod / 2;
|
||||
for (unsigned int i=0; i<vv.size(); i++)
|
||||
{
|
||||
pol[i] = vv[i];
|
||||
if (pol[i] > half_mod)
|
||||
pol[i] -= mod;
|
||||
pol[i] %= (*Field_Data).get_prime();
|
||||
if (pol[i]<0) { pol[i]+=(*Field_Data).get_prime(); }
|
||||
}
|
||||
type = Polynomial;
|
||||
}
|
||||
|
||||
|
||||
template<>
|
||||
void Plaintext<gfp,PPData,bigint>::set_poly_mod(const vector<bigint>& vv,const bigint& mod)
|
||||
{
|
||||
b = vv;
|
||||
vector<bigint>& pol = b;
|
||||
for (unsigned int i=0; i<vv.size(); i++)
|
||||
{ if (vv[i]>mod/2) { pol[i]=vv[i]-mod; }
|
||||
else { pol[i]=vv[i]; }
|
||||
pol[i]=pol[i]%(*Field_Data).get_prime();
|
||||
if (pol[i]<0) { pol[i]+=(*Field_Data).get_prime(); }
|
||||
}
|
||||
type = Polynomial;
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
template<>
|
||||
@@ -229,7 +215,8 @@ void Plaintext<gf2n_short,P2Data,int>::set_poly_mod(const Generator<bigint>& gen
|
||||
}
|
||||
|
||||
|
||||
void rand_poly(vector<bigint>& b,PRNG& G,const bigint& pr,bool positive=true)
|
||||
template<class T>
|
||||
void rand_poly(vector<T>& b,PRNG& G,const bigint& pr,bool positive=true)
|
||||
{
|
||||
for (unsigned int i=0; i<b.size(); i++)
|
||||
{
|
||||
@@ -416,7 +403,7 @@ void Plaintext<T,FD,S>::assign_constant(T constant, PT_Type t)
|
||||
|
||||
template<class T,class FD,class S>
|
||||
Plaintext<T,FD,S>& Plaintext<T,FD,S>::operator+=(
|
||||
const Plaintext<T,FD,S>& y)
|
||||
const Plaintext& y)
|
||||
{
|
||||
if (Field_Data!=y.Field_Data) { throw field_mismatch(); }
|
||||
|
||||
@@ -687,8 +674,16 @@ void Plaintext<gf2n_short,P2Data,int>::negate()
|
||||
|
||||
|
||||
|
||||
template<class T, class FD, class S>
|
||||
Rq_Element Plaintext<T, FD, S>::mul_by_X_i(int i, const FHE_PK& pk) const
|
||||
{
|
||||
return Rq_Element(pk.get_params(), *this).mul_by_X_i(i);
|
||||
}
|
||||
|
||||
|
||||
|
||||
template<class T,class FD,class S>
|
||||
bool Plaintext<T,FD,S>::equals(const Plaintext<T,FD,S>& x) const
|
||||
bool Plaintext<T,FD,S>::equals(const Plaintext& x) const
|
||||
{
|
||||
if (Field_Data!=x.Field_Data) { return false; }
|
||||
if (type!=x.type)
|
||||
@@ -730,7 +725,7 @@ void Plaintext<T,FD,S>::unpack(octetStream& o)
|
||||
if (size != b.size())
|
||||
throw length_error("unexpected length received");
|
||||
for (unsigned int i = 0; i < b.size(); i++)
|
||||
o.get(b[i]);
|
||||
b[i] = o.get<S>();
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -18,10 +18,14 @@
|
||||
*/
|
||||
|
||||
#include "FHE/Generator.h"
|
||||
#include "Math/fixint.h"
|
||||
|
||||
#include <vector>
|
||||
using namespace std;
|
||||
|
||||
class FHE_PK;
|
||||
class Rq_Element;
|
||||
|
||||
// Forward declaration as apparently this is needed for friends in templates
|
||||
template<class T,class FD,class S> class Plaintext;
|
||||
template<class T,class FD,class S> ostream& operator<<(ostream& s,const Plaintext<T,FD,S>& e);
|
||||
@@ -35,9 +39,11 @@ enum condition { Full, Diagonal, Bits };
|
||||
|
||||
enum PT_Type { Polynomial, Evaluation, Both };
|
||||
|
||||
template<class T,class FD,class S>
|
||||
template<class T,class FD,class _>
|
||||
class Plaintext
|
||||
{
|
||||
typedef typename FD::poly_type S;
|
||||
|
||||
int n_slots;
|
||||
int degree;
|
||||
|
||||
@@ -60,7 +66,7 @@ class Plaintext
|
||||
const FD& get_field() const { return *Field_Data; }
|
||||
unsigned int num_slots() const { return n_slots; }
|
||||
|
||||
void assign(const Plaintext<T,FD,S>& p)
|
||||
void assign(const Plaintext& p)
|
||||
{ Field_Data=p.Field_Data;
|
||||
a=p.a; b=p.b; type=p.type;
|
||||
n_slots = p.n_slots;
|
||||
@@ -70,12 +76,7 @@ class Plaintext
|
||||
Plaintext(const FD& FieldD, PT_Type type = Polynomial)
|
||||
{ Field_Data=&FieldD; set_sizes(); allocate(type); }
|
||||
|
||||
Plaintext(const Plaintext<T,FD,S>& p) { assign(p); }
|
||||
~Plaintext() { ; }
|
||||
Plaintext& operator=(const Plaintext<T,FD,S>& p)
|
||||
{ if (this!=&p) { assign(p); }
|
||||
return *this;
|
||||
}
|
||||
Plaintext(const FD& FieldD, const Rq_Element& other);
|
||||
|
||||
void allocate(PT_Type type) const;
|
||||
void allocate() const { allocate(type); }
|
||||
@@ -117,17 +118,23 @@ class Plaintext
|
||||
void set_poly(const vector<S>& v)
|
||||
{ type=Polynomial; b=v; }
|
||||
const vector<S>& get_poly() const
|
||||
{ if (type==Evaluation) { throw rep_mismatch(); }
|
||||
{
|
||||
to_poly();
|
||||
return b;
|
||||
}
|
||||
|
||||
Iterator<S> get_iterator() const { to_poly(); return b; }
|
||||
|
||||
void from(const Generator<bigint>& source) const;
|
||||
|
||||
// This sets a poly from a vector of bigint's which needs centering
|
||||
// modulo mod, before assigning (used in decryption)
|
||||
// vv[i] is already assumed reduced modulo mod though but in
|
||||
// range [0,...,mod)
|
||||
void set_poly_mod(const vector<bigint>& vv,const bigint& mod);
|
||||
void set_poly_mod(const vector<bigint>& vv,const bigint& mod)
|
||||
{
|
||||
set_poly_mod(Iterator<bigint>(vv), mod);
|
||||
}
|
||||
void set_poly_mod(const Generator<bigint>& generator, const bigint& mod);
|
||||
|
||||
// Converts between Evaluation,Polynomial and Both representations
|
||||
@@ -144,36 +151,38 @@ class Plaintext
|
||||
void assign_one(PT_Type t = Evaluation);
|
||||
void assign_constant(T constant, PT_Type t = Evaluation);
|
||||
|
||||
friend void add<>(Plaintext<T,FD,S>& z,const Plaintext<T,FD,S>& x,const Plaintext<T,FD,S>& y);
|
||||
friend void sub<>(Plaintext<T,FD,S>& z,const Plaintext<T,FD,S>& x,const Plaintext<T,FD,S>& y);
|
||||
friend void mul<>(Plaintext<T,FD,S>& z,const Plaintext<T,FD,S>& x,const Plaintext<T,FD,S>& y);
|
||||
friend void sqr<>(Plaintext<T,FD,S>& z,const Plaintext<T,FD,S>& x);
|
||||
friend void add<>(Plaintext& z,const Plaintext& x,const Plaintext& y);
|
||||
friend void sub<>(Plaintext& z,const Plaintext& x,const Plaintext& y);
|
||||
friend void mul<>(Plaintext& z,const Plaintext& x,const Plaintext& y);
|
||||
friend void sqr<>(Plaintext& z,const Plaintext& x);
|
||||
|
||||
Plaintext<T,FD,S> operator+(const Plaintext<T,FD,S>& x) const
|
||||
{ Plaintext<T,FD,S> res(*Field_Data); add(res, *this, x); return res; }
|
||||
Plaintext<T,FD,S> operator-(const Plaintext<T,FD,S>& x) const
|
||||
{ Plaintext<T,FD,S> res(*Field_Data); sub(res, *this, x); return res; }
|
||||
Plaintext operator+(const Plaintext& x) const
|
||||
{ Plaintext res(*Field_Data); add(res, *this, x); return res; }
|
||||
Plaintext operator-(const Plaintext& x) const
|
||||
{ Plaintext res(*Field_Data); sub(res, *this, x); return res; }
|
||||
|
||||
void mul(const Plaintext<T,FD,S>& x, const Plaintext<T,FD,S>& y)
|
||||
void mul(const Plaintext& x, const Plaintext& y)
|
||||
{ x.from_poly(); y.from_poly(); ::mul(*this, x, y); }
|
||||
|
||||
Plaintext<T,FD,S> operator*(const Plaintext<T,FD,S>& x)
|
||||
{ Plaintext<T,FD,S> res(*Field_Data); res.mul(*this, x); return res; }
|
||||
Plaintext operator*(const Plaintext& x)
|
||||
{ Plaintext res(*Field_Data); res.mul(*this, x); return res; }
|
||||
|
||||
Plaintext<T,FD,S>& operator+=(const Plaintext<T,FD,S>& y);
|
||||
Plaintext<T,FD,S>& operator-=(const Plaintext<T,FD,S>& y)
|
||||
Plaintext& operator+=(const Plaintext& y);
|
||||
Plaintext& operator-=(const Plaintext& y)
|
||||
{ to_poly(); y.to_poly(); ::sub(*this, *this, y); return *this; }
|
||||
|
||||
void negate();
|
||||
|
||||
bool equals(const Plaintext<T,FD,S>& x) const;
|
||||
bool operator!=(const Plaintext<T,FD,S>& x) { return !equals(x); }
|
||||
Rq_Element mul_by_X_i(int i, const FHE_PK& pk) const;
|
||||
|
||||
bool equals(const Plaintext& x) const;
|
||||
bool operator!=(const Plaintext& x) { return !equals(x); }
|
||||
|
||||
bool is_diagonal() const { throw not_implemented(); }
|
||||
bool is_binary() const { throw not_implemented(); }
|
||||
|
||||
friend ostream& operator<< <>(ostream& s,const Plaintext<T,FD,S>& e);
|
||||
friend istream& operator>> <>(istream& s,Plaintext<T,FD,S>& e);
|
||||
friend ostream& operator<< <>(ostream& s,const Plaintext& e);
|
||||
friend istream& operator>> <>(istream& s,Plaintext& e);
|
||||
|
||||
/* Pack and unpack into an octetStream
|
||||
* For unpack we assume the FFTD has been assigned correctly already
|
||||
|
||||
@@ -9,8 +9,10 @@
|
||||
|
||||
class FHE_PK;
|
||||
|
||||
class Int_Random_Coins : public AddableMatrix<bigint>
|
||||
class Int_Random_Coins : public AddableMatrix<fixint<0>>
|
||||
{
|
||||
typedef value_type::value_type T;
|
||||
|
||||
const FHE_Params* params;
|
||||
public:
|
||||
Int_Random_Coins(const FHE_Params& params) : params(¶ms)
|
||||
@@ -20,14 +22,16 @@ public:
|
||||
|
||||
void sample(PRNG& G)
|
||||
{
|
||||
(*this)[0].from(HalfGenerator(G));
|
||||
(*this)[0].from(HalfGenerator<T>(G));
|
||||
for (int i = 1; i < 3; i++)
|
||||
(*this)[i].from(GaussianGenerator(params->get_DG(), G));
|
||||
(*this)[i].from(GaussianGenerator<T>(params->get_DG(), G));
|
||||
}
|
||||
};
|
||||
|
||||
class Random_Coins
|
||||
{
|
||||
typedef bigint T;
|
||||
|
||||
Rq_Element uu,vv,ww;
|
||||
const FHE_Params *params;
|
||||
|
||||
@@ -56,16 +60,25 @@ class Random_Coins
|
||||
|
||||
template <class T>
|
||||
void assign(const vector<T>& u,const vector<T>& v,const vector<T>& w)
|
||||
{ uu.from_vec(u); vv.from_vec(v); ww.from_vec(w); }
|
||||
{
|
||||
uu.from(u);
|
||||
vv.from(v);
|
||||
ww.from(w);
|
||||
}
|
||||
|
||||
void assign(const Int_Random_Coins& rc)
|
||||
{ uu.from_vec(rc[0]); vv.from_vec(rc[1]); ww.from_vec(rc[2]); }
|
||||
{
|
||||
uu.from(rc[0]);
|
||||
vv.from(rc[1]);
|
||||
ww.from(rc[2]);
|
||||
}
|
||||
|
||||
/* Generate a standard distribution */
|
||||
void generate(PRNG& G)
|
||||
{ uu.from(HalfGenerator(G));
|
||||
vv.from(GaussianGenerator(params->get_DG(), G));
|
||||
ww.from(GaussianGenerator(params->get_DG(), G));
|
||||
{
|
||||
uu.from(HalfGenerator<T>(G));
|
||||
vv.from(GaussianGenerator<T>(params->get_DG(), G));
|
||||
ww.from(GaussianGenerator<T>(params->get_DG(), G));
|
||||
}
|
||||
|
||||
// Generate all from Uniform in range (-B,...B)
|
||||
@@ -74,9 +87,9 @@ class Random_Coins
|
||||
if (B1 == 0)
|
||||
uu.assign_zero();
|
||||
else
|
||||
uu.from(UniformGenerator(G,numBits(B1)));
|
||||
vv.from(UniformGenerator(G,numBits(B2)));
|
||||
ww.from(UniformGenerator(G,numBits(B3)));
|
||||
uu.from(UniformGenerator<T>(G,numBits(B1)));
|
||||
vv.from(UniformGenerator<T>(G,numBits(B2)));
|
||||
ww.from(UniformGenerator<T>(G,numBits(B3)));
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -147,6 +147,47 @@ void mul(Ring_Element& ans,const Ring_Element& a,const modp& b)
|
||||
}
|
||||
|
||||
|
||||
Ring_Element Ring_Element::mul_by_X_i(int j) const
|
||||
{
|
||||
Ring_Element ans;
|
||||
auto& a = *this;
|
||||
ans.partial_assign(a);
|
||||
if (ans.rep == evaluation)
|
||||
{
|
||||
modp xj, xj2;
|
||||
Power(xj, (*ans.FFTD).get_root(0), j, (*a.FFTD).get_prD());
|
||||
Sqr(xj2, xj, (*a.FFTD).get_prD());
|
||||
for (int i= 0; i < (*ans.FFTD).phi_m(); i++)
|
||||
{
|
||||
Mul(ans.element[i], a.element[i], xj, (*a.FFTD).get_prD());
|
||||
Mul(xj, xj, xj2, (*a.FFTD).get_prD());
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
Ring_Element aa(*ans.FFTD, ans.rep);
|
||||
for (int i= 0; i < (*ans.FFTD).phi_m(); i++)
|
||||
{
|
||||
int k= j + i, s= 1;
|
||||
while (k >= (*ans.FFTD).phi_m())
|
||||
{
|
||||
k-= (*ans.FFTD).phi_m();
|
||||
s= -s;
|
||||
}
|
||||
if (s == 1)
|
||||
{
|
||||
aa.element[k]= a.element[i];
|
||||
}
|
||||
else
|
||||
{
|
||||
Negate(aa.element[k], a.element[i], (*a.FFTD).get_prD());
|
||||
}
|
||||
}
|
||||
ans= aa;
|
||||
}
|
||||
return ans;
|
||||
}
|
||||
|
||||
|
||||
void Ring_Element::randomize(PRNG& G,bool Diag)
|
||||
{
|
||||
@@ -318,20 +359,6 @@ void Ring_Element::from_vec(const vector<int>& v)
|
||||
// cout << "RE:from_vec<int>:: " << *this << endl;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void Ring_Element::from(const Generator<T>& generator)
|
||||
{
|
||||
RepType t=rep;
|
||||
rep=polynomial;
|
||||
T tmp;
|
||||
for (int i=0; i<(*FFTD).phi_m(); i++)
|
||||
{
|
||||
generator.get(tmp);
|
||||
element[i].convert_destroy(tmp, (*FFTD).get_prD());
|
||||
}
|
||||
change_rep(t);
|
||||
}
|
||||
|
||||
ConversionIterator Ring_Element::get_iterator() const
|
||||
{
|
||||
if (rep != polynomial)
|
||||
@@ -389,6 +416,7 @@ modp Ring_Element::get_constant() const
|
||||
|
||||
void store(octetStream& o,const vector<modp>& v,const Zp_Data& ZpD)
|
||||
{
|
||||
ZpD.pack(o);
|
||||
o.store((int)v.size());
|
||||
for (unsigned int i=0; i<v.size(); i++)
|
||||
{ v[i].pack(o,ZpD); }
|
||||
@@ -397,6 +425,12 @@ void store(octetStream& o,const vector<modp>& v,const Zp_Data& ZpD)
|
||||
|
||||
void get(octetStream& o,vector<modp>& v,const Zp_Data& ZpD)
|
||||
{
|
||||
Zp_Data check_Zpd;
|
||||
check_Zpd.unpack(o);
|
||||
if (check_Zpd != ZpD)
|
||||
throw runtime_error(
|
||||
"mismatch: " + to_string(check_Zpd.pr_bit_length) + "/"
|
||||
+ to_string(ZpD.pr_bit_length));
|
||||
unsigned int length;
|
||||
o.get(length);
|
||||
v.resize(length);
|
||||
@@ -408,7 +442,7 @@ void get(octetStream& o,vector<modp>& v,const Zp_Data& ZpD)
|
||||
void Ring_Element::pack(octetStream& o) const
|
||||
{
|
||||
check_size();
|
||||
o.store(rep);
|
||||
o.store(unsigned(rep));
|
||||
store(o,element,(*FFTD).get_prD());
|
||||
}
|
||||
|
||||
|
||||
@@ -90,6 +90,8 @@ class Ring_Element
|
||||
friend void mul(Ring_Element& ans,const Ring_Element& a,const Ring_Element& b);
|
||||
friend void mul(Ring_Element& ans,const Ring_Element& a,const modp& b);
|
||||
|
||||
Ring_Element mul_by_X_i(int i) const;
|
||||
|
||||
void randomize(PRNG& G,bool Diag=false);
|
||||
|
||||
bool equals(const Ring_Element& a) const;
|
||||
@@ -116,6 +118,12 @@ class Ring_Element
|
||||
template <class T>
|
||||
void from(const Generator<T>& generator);
|
||||
|
||||
template <class T>
|
||||
void from(const vector<T>& source)
|
||||
{
|
||||
from(Iterator<T>(source));
|
||||
}
|
||||
|
||||
// This gets the constant term of the poly rep as a modp element
|
||||
modp get_constant() const;
|
||||
modp get_element(int i) const { return element[i]; }
|
||||
@@ -167,5 +175,20 @@ public:
|
||||
inline void mul(Ring_Element& ans,const modp& a,const Ring_Element& b)
|
||||
{ mul(ans,b,a); }
|
||||
|
||||
|
||||
template <class T>
|
||||
void Ring_Element::from(const Generator<T>& generator)
|
||||
{
|
||||
RepType t=rep;
|
||||
rep=polynomial;
|
||||
T tmp;
|
||||
for (int i=0; i<(*FFTD).phi_m(); i++)
|
||||
{
|
||||
generator.get(tmp);
|
||||
element[i].convert_destroy(tmp, (*FFTD).get_prD());
|
||||
}
|
||||
change_rep(t);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
|
||||
@@ -1,6 +1,12 @@
|
||||
#include "Rq_Element.h"
|
||||
#include "FHE_Keys.h"
|
||||
#include "Exceptions/Exceptions.h"
|
||||
|
||||
Rq_Element::Rq_Element(const FHE_PK& pk) :
|
||||
Rq_Element(pk.get_params().FFTD())
|
||||
{
|
||||
}
|
||||
|
||||
Rq_Element::Rq_Element(const vector<FFT_Data>& prd, RepType r0, RepType r1)
|
||||
{
|
||||
if (prd.size() > 0)
|
||||
@@ -57,6 +63,19 @@ void Rq_Element::negate()
|
||||
a[i].negate();
|
||||
}
|
||||
|
||||
Rq_Element Rq_Element::mul_by_X_i(int i) const
|
||||
{
|
||||
Rq_Element res;
|
||||
res.lev = lev;
|
||||
res.a.clear();
|
||||
for (auto& x : a)
|
||||
{
|
||||
auto tmp = x.mul_by_X_i(i);
|
||||
res.a.push_back(tmp);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
void add(Rq_Element& ans,const Rq_Element& ra,const Rq_Element& rb)
|
||||
{
|
||||
ans.partial_assign(ra, rb);
|
||||
@@ -173,19 +192,6 @@ void Rq_Element::from_vec(const vector<int>& v,int level)
|
||||
a[i].from_vec(v);
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void Rq_Element::from(const Generator<T>& generator, int level)
|
||||
{
|
||||
set_level(level);
|
||||
if (lev == 1)
|
||||
{
|
||||
auto clone = generator.clone();
|
||||
a[1].from(*clone);
|
||||
delete clone;
|
||||
}
|
||||
a[0].from(generator);
|
||||
}
|
||||
|
||||
vector<bigint> Rq_Element::to_vec_bigint() const
|
||||
{
|
||||
vector<bigint> v;
|
||||
@@ -220,7 +226,7 @@ void Rq_Element::to_vec_bigint(vector<bigint>& v) const
|
||||
}
|
||||
}
|
||||
|
||||
ConversionIterator Rq_Element::get_iterator()
|
||||
ConversionIterator Rq_Element::get_iterator() const
|
||||
{
|
||||
if (lev != 0)
|
||||
throw not_implemented();
|
||||
@@ -339,7 +345,8 @@ void Rq_Element::raise_level()
|
||||
void Rq_Element::check_level() const
|
||||
{
|
||||
if ((unsigned)lev > (unsigned)n_mults())
|
||||
throw range_error("level out of range");
|
||||
throw range_error(
|
||||
"level out of range: " + to_string(lev) + "/" + to_string(n_mults()));
|
||||
}
|
||||
|
||||
void Rq_Element::partial_assign(const Rq_Element& x, const Rq_Element& y)
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
#include "FHE/FHE_Params.h"
|
||||
#include "FHE/tools.h"
|
||||
#include "FHE/Generator.h"
|
||||
#include "Plaintext.h"
|
||||
#include <vector>
|
||||
|
||||
// Forward declare the friend functions
|
||||
@@ -62,9 +63,18 @@ protected:
|
||||
Rq_Element(const FHE_Params& params) :
|
||||
Rq_Element(params.FFTD()) {}
|
||||
|
||||
Rq_Element(const FHE_PK& pk);
|
||||
|
||||
Rq_Element(const Ring_Element& b0,const Ring_Element& b1) :
|
||||
a({b0, b1}), lev(n_mults()) {}
|
||||
|
||||
template<class T, class FD, class S>
|
||||
Rq_Element(const FHE_Params& params, const Plaintext<T, FD, S>& plaintext) :
|
||||
Rq_Element(params)
|
||||
{
|
||||
from(plaintext.get_iterator());
|
||||
}
|
||||
|
||||
// Destructor
|
||||
~Rq_Element()
|
||||
{ ; }
|
||||
@@ -91,6 +101,8 @@ protected:
|
||||
// Multiply something by p1 and make level 1
|
||||
void mul_by_p1();
|
||||
|
||||
Rq_Element mul_by_X_i(int i) const;
|
||||
|
||||
void randomize(PRNG& G,int lev=-1);
|
||||
|
||||
// Scale from level 1 to level 0, if at level 0 do nothing
|
||||
@@ -113,10 +125,16 @@ protected:
|
||||
vector<bigint> to_vec_bigint() const;
|
||||
void to_vec_bigint(vector<bigint>& v) const;
|
||||
|
||||
ConversionIterator get_iterator();
|
||||
ConversionIterator get_iterator() const;
|
||||
template <class T>
|
||||
void from(const Generator<T>& generator, int level=-1);
|
||||
|
||||
template <class T>
|
||||
void from(const vector<T>& source, int level=-1)
|
||||
{
|
||||
from(Iterator<T>(source), level);
|
||||
}
|
||||
|
||||
bigint infinity_norm() const;
|
||||
|
||||
bigint get_prime(int i) const
|
||||
@@ -156,9 +174,22 @@ template<class S>
|
||||
Rq_Element& Rq_Element::operator+=(const vector<S>& other)
|
||||
{
|
||||
Rq_Element tmp = *this;
|
||||
tmp.from_vec(other, lev);
|
||||
tmp.from(Iterator<S>(other), lev);
|
||||
add(*this, *this, tmp);
|
||||
return *this;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void Rq_Element::from(const Generator<T>& generator, int level)
|
||||
{
|
||||
set_level(level);
|
||||
if (lev == 1)
|
||||
{
|
||||
auto clone = generator.clone();
|
||||
a[1].from(*clone);
|
||||
delete clone;
|
||||
}
|
||||
a[0].from(generator);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
|
||||
#include "Subroutines.h"
|
||||
#include "modp.h"
|
||||
#include "Math/modp.h"
|
||||
|
||||
#include "modp.hpp"
|
||||
#include "Math/modp.hpp"
|
||||
|
||||
void Subs(modp& ans,const vector<int>& poly,const modp& x,const Zp_Data& ZpD)
|
||||
{
|
||||
@@ -18,6 +18,11 @@ CutAndChooseMachine::CutAndChooseMachine(int argc, const char** argv)
|
||||
"--covert" // Flag token.
|
||||
);
|
||||
parse_options(argc, argv);
|
||||
if (produce_inputs)
|
||||
{
|
||||
cerr << "Producing input tuples is not implemented" << endl;
|
||||
exit(1);
|
||||
}
|
||||
covert = opt.isSet("--covert");
|
||||
if (not covert and sec != 40)
|
||||
throw runtime_error("active cut-and-choose only implemented for 40-bit security");
|
||||
|
||||
@@ -8,6 +8,10 @@
|
||||
#include "Protocols/fake-stuff.h"
|
||||
#include "FHE/NTL-Subs.h"
|
||||
#include "Tools/benchmarking.h"
|
||||
#include "Tools/Bundle.h"
|
||||
#include "PairwiseSetup.h"
|
||||
#include "Proof.h"
|
||||
#include "SimpleMachine.h"
|
||||
|
||||
#include <iostream>
|
||||
using namespace std;
|
||||
@@ -58,8 +62,8 @@ void PartSetup<FD>::generate_setup(int n_parties, int plaintext_length, int sec,
|
||||
int slack, bool round_up)
|
||||
{
|
||||
sec = max(sec, 40);
|
||||
::generate_setup(n_parties, plaintext_length, sec, params, FieldD,
|
||||
slack, round_up);
|
||||
Parameters(n_parties, plaintext_length, sec, slack, round_up).generate_setup(
|
||||
params, FieldD);
|
||||
params.set_sec(sec);
|
||||
pk = FHE_PK(params, FieldD.get_prime());
|
||||
sk = FHE_SK(params, FieldD.get_prime());
|
||||
@@ -254,6 +258,7 @@ void DataSetup::output(int my_number, int nn, bool specific_dir)
|
||||
template <class FD>
|
||||
void PartSetup<FD>::pack(octetStream& os)
|
||||
{
|
||||
os.append((octet*)"PARTSETU", 8);
|
||||
params.pack(os);
|
||||
FieldD.pack(os);
|
||||
pk.pack(os);
|
||||
@@ -265,8 +270,15 @@ void PartSetup<FD>::pack(octetStream& os)
|
||||
template <class FD>
|
||||
void PartSetup<FD>::unpack(octetStream& os)
|
||||
{
|
||||
char tag[8];
|
||||
os.consume((octet*) tag, 8);
|
||||
if (memcmp(tag, "PARTSETU", 8))
|
||||
throw runtime_error("invalid serialization of setup");
|
||||
params.unpack(os);
|
||||
FieldD.unpack(os);
|
||||
pk = {params, FieldD};
|
||||
sk = pk;
|
||||
calpha = params;
|
||||
pk.unpack(os);
|
||||
sk.unpack(os);
|
||||
calpha.unpack(os);
|
||||
@@ -305,5 +317,94 @@ bool PartSetup<FD>::operator!=(const PartSetup<FD>& other)
|
||||
return false;
|
||||
}
|
||||
|
||||
template<class FD>
|
||||
void PartSetup<FD>::secure_init(Player& P, MachineBase& machine,
|
||||
int plaintext_length, int sec)
|
||||
{
|
||||
::secure_init(*this, P, machine, plaintext_length, sec);
|
||||
}
|
||||
|
||||
template<class FD>
|
||||
void PartSetup<FD>::generate(Player& P, MachineBase&, int plaintext_length,
|
||||
int sec)
|
||||
{
|
||||
generate_setup(P.num_players(), plaintext_length, sec,
|
||||
INTERACTIVE_SPDZ1_SLACK, false);
|
||||
}
|
||||
|
||||
template<class FD>
|
||||
void PartSetup<FD>::check(Player& P, MachineBase& machine)
|
||||
{
|
||||
Bundle<octetStream> bundle(P);
|
||||
bundle.mine.store(machine.extra_slack);
|
||||
auto& os = bundle.mine;
|
||||
params.pack(os);
|
||||
FieldD.hash(os);
|
||||
pk.pack(os);
|
||||
calpha.pack(os);
|
||||
bundle.compare(P);
|
||||
}
|
||||
|
||||
template<class FD>
|
||||
void PartSetup<FD>::covert_key_generation(Player& P,
|
||||
MultiplicativeMachine& machine, int num_runs)
|
||||
{
|
||||
auto& setup = machine.setup.part<FD>();
|
||||
Run_Gen_Protocol(setup.pk, setup.sk, P, num_runs, false);
|
||||
}
|
||||
|
||||
template<class FD>
|
||||
void PartSetup<FD>::covert_mac_generation(Player& P,
|
||||
MultiplicativeMachine& machine, int num_runs)
|
||||
{
|
||||
auto& setup = machine.setup.part<FD>();
|
||||
generate_mac_key(setup.alphai, setup.calpha, setup.FieldD, setup.pk, P,
|
||||
num_runs);
|
||||
}
|
||||
|
||||
template<class FD>
|
||||
void PartSetup<FD>::covert_secrets_generation(Player& P,
|
||||
MultiplicativeMachine& machine, int num_runs)
|
||||
{
|
||||
octetStream os;
|
||||
params.pack(os);
|
||||
FieldD.pack(os);
|
||||
string filename = PREP_DIR "ChaiGear-Secrets-" + to_string(num_runs) + "-"
|
||||
+ os.check_sum(20).get_str(16) + "-P" + to_string(P.my_num());
|
||||
|
||||
string error;
|
||||
|
||||
try
|
||||
{
|
||||
ifstream input(filename);
|
||||
os.input(input);
|
||||
unpack(os);
|
||||
}
|
||||
catch (exception& e)
|
||||
{
|
||||
error = e.what();
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
check(P, machine);
|
||||
}
|
||||
catch (mismatch_among_parties& e)
|
||||
{
|
||||
error = e.what();
|
||||
}
|
||||
|
||||
if (not error.empty())
|
||||
{
|
||||
cerr << "Running secrets generation because " << error << endl;
|
||||
covert_key_generation(P, machine, num_runs);
|
||||
covert_mac_generation(P, machine, num_runs);
|
||||
ofstream output(filename);
|
||||
octetStream os;
|
||||
pack(os);
|
||||
os.output(output);
|
||||
}
|
||||
}
|
||||
|
||||
template class PartSetup<FFT_Data>;
|
||||
template class PartSetup<P2Data>;
|
||||
|
||||
@@ -12,6 +12,8 @@
|
||||
#include "Math/Setup.h"
|
||||
|
||||
class DataSetup;
|
||||
class MachineBase;
|
||||
class MultiplicativeMachine;
|
||||
|
||||
template <class FD>
|
||||
class PartSetup
|
||||
@@ -26,6 +28,11 @@ public:
|
||||
Ciphertext calpha;
|
||||
typename FD::T alphai;
|
||||
|
||||
static string name()
|
||||
{
|
||||
return "GlobalParams-" + T::type_string();
|
||||
}
|
||||
|
||||
PartSetup();
|
||||
void generate_setup(int n_parties, int plaintext_length, int sec, int slack,
|
||||
bool round_up);
|
||||
@@ -42,6 +49,19 @@ public:
|
||||
|
||||
void check(int sec) const;
|
||||
bool operator!=(const PartSetup<FD>& other);
|
||||
|
||||
void secure_init(Player& P, MachineBase& machine, int plaintext_length,
|
||||
int sec);
|
||||
void generate(Player& P, MachineBase& machine, int plaintext_length,
|
||||
int sec);
|
||||
void check(Player& P, MachineBase& machine);
|
||||
|
||||
void covert_key_generation(Player& P, MultiplicativeMachine& machine,
|
||||
int num_runs);
|
||||
void covert_mac_generation(Player& P, MultiplicativeMachine& machine,
|
||||
int num_runs);
|
||||
void covert_secrets_generation(Player& P, MultiplicativeMachine& machine,
|
||||
int num_runs);
|
||||
};
|
||||
|
||||
class DataSetup
|
||||
|
||||
@@ -13,7 +13,7 @@ DistDecrypt<FD>::DistDecrypt(const Player& P, const FHE_SK& share,
|
||||
bigint limit = pk.get_params().Q() << 64;
|
||||
vv.allocate_slots(limit);
|
||||
vv1.allocate_slots(limit);
|
||||
mf.allocate_slots(pk.get_params().p0() << 64);
|
||||
mf.allocate_slots(pk.p() << 64);
|
||||
}
|
||||
|
||||
template<class FD>
|
||||
|
||||
@@ -262,7 +262,7 @@ void EncCommit<T,FD,S>::Create_More() const
|
||||
{ if (cond!=Full)
|
||||
{ throw not_implemented(); }
|
||||
else
|
||||
{ m_Delta.from(UniformGenerator(Gseed[i],numBits(Bound1))); }
|
||||
{ m_Delta.from(UniformGenerator<bigint>(Gseed[i],numBits(Bound1))); }
|
||||
rc_Delta.generateUniform(Gseed[i],Bound2,Bound3,Bound3);
|
||||
Ciphertext Delta(params);
|
||||
(*pk).quasi_encrypt(Delta,m_Delta,rc_Delta);
|
||||
@@ -319,7 +319,7 @@ void EncCommit<T,FD,S>::Create_More() const
|
||||
if (cond!=Full)
|
||||
{ throw not_implemented(); }
|
||||
else
|
||||
{ mm.from(UniformGenerator(G,numBits(Bound1))); }
|
||||
{ mm.from(UniformGenerator<bigint>(G, numBits(Bound1))); }
|
||||
rr.generateUniform(G,Bound2,Bound3,Bound3);
|
||||
(*pk).quasi_encrypt(cc,mm,rr);
|
||||
occ.reset_write_head();
|
||||
@@ -357,10 +357,10 @@ void EncCommit<T,FD,S>::Create_More() const
|
||||
{ throw not_implemented();
|
||||
}
|
||||
else
|
||||
{ m_Delta.from(UniformGenerator(G,numBits(Bound1))); }
|
||||
{ m_Delta.from(UniformGenerator<bigint>(G, numBits(Bound1))); }
|
||||
rc_Delta.generateUniform(G,Bound2,Bound3,Bound3);
|
||||
|
||||
Iterator<S> vm=m[i].get_iterator();
|
||||
auto vm=m[i].get_iterator();
|
||||
z[0].from(vm);
|
||||
add(z[0],z[0],m_Delta);
|
||||
add(rr,rc[i],rc_Delta);
|
||||
|
||||
@@ -12,13 +12,17 @@
|
||||
#include "FHE/Plaintext.h"
|
||||
#include "Tools/MemoryUsage.h"
|
||||
|
||||
class MachineBase;
|
||||
|
||||
template<class T,class FD,class S>
|
||||
class EncCommitBase
|
||||
{
|
||||
public:
|
||||
size_t volatile_memory;
|
||||
|
||||
EncCommitBase() : volatile_memory(0) {}
|
||||
const MachineBase* machine;
|
||||
|
||||
EncCommitBase(const MachineBase* machine = 0) : volatile_memory(0), machine(machine) {}
|
||||
virtual ~EncCommitBase() {}
|
||||
virtual condition get_condition() { return Full; }
|
||||
virtual void next(Plaintext<T,FD,S>& mess, Ciphertext& c)
|
||||
|
||||
@@ -27,7 +27,7 @@ Multiplier<FD>::Multiplier(int offset, PairwiseMachine& machine, Player& P,
|
||||
product_share(machine.setup<FD>().FieldD), rc(machine.pk),
|
||||
volatile_capacity(0)
|
||||
{
|
||||
product_share.allocate_slots(machine.setup<FD>().params.p0() << 64);
|
||||
product_share.allocate_slots(machine.pk.p() << 64);
|
||||
}
|
||||
|
||||
template <class FD>
|
||||
@@ -35,7 +35,7 @@ void Multiplier<FD>::multiply_and_add(Plaintext_<FD>& res,
|
||||
const Ciphertext& enc_a, const Plaintext_<FD>& b)
|
||||
{
|
||||
Rq_Element bb(enc_a.get_params(), evaluation, evaluation);
|
||||
bb.from_vec(b.get_poly());
|
||||
bb.from(b.get_iterator());
|
||||
multiply_and_add(res, enc_a, bb);
|
||||
}
|
||||
|
||||
|
||||
@@ -7,6 +7,8 @@
|
||||
#include "FHEOffline/PairwiseMachine.h"
|
||||
#include "FHEOffline/Producer.h"
|
||||
#include "Protocols/SemiShare.h"
|
||||
#include "GC/SemiSecret.h"
|
||||
#include "GC/SemiPrep.h"
|
||||
|
||||
#include "Protocols/MAC_Check.hpp"
|
||||
#include "Protocols/SemiInput.hpp"
|
||||
@@ -21,20 +23,21 @@ PairwiseGenerator<FD>::PairwiseGenerator(int thread_num,
|
||||
thread_num, machine.output),
|
||||
EC(P, machine.other_pks, machine.setup<FD>().FieldD, timers, machine, *this),
|
||||
MC(machine.setup<FD>().alphai),
|
||||
C(machine.sec, machine.setup<FD>().params), volatile_memory(0),
|
||||
n_ciphertexts(Proof::n_ciphertext_per_proof(machine.sec, machine.pk)),
|
||||
C(n_ciphertexts, machine.setup<FD>().params), volatile_memory(0),
|
||||
machine(machine)
|
||||
{
|
||||
for (int i = 1; i < P.num_players(); i++)
|
||||
multipliers.push_back(new Multiplier<FD>(i, *this));
|
||||
const FD& FieldD = machine.setup<FD>().FieldD;
|
||||
a.resize(machine.sec, FieldD);
|
||||
b.resize(machine.sec, FieldD);
|
||||
c.resize(machine.sec, FieldD);
|
||||
a.resize(n_ciphertexts, FieldD);
|
||||
b.resize(n_ciphertexts, FieldD);
|
||||
c.resize(n_ciphertexts, FieldD);
|
||||
a.allocate_slots(FieldD.get_prime());
|
||||
b.allocate_slots(FieldD.get_prime());
|
||||
// extra limb for addition
|
||||
c.allocate_slots((bigint)FieldD.get_prime() << 64);
|
||||
b_mod_q.resize(machine.sec,
|
||||
b_mod_q.resize(n_ciphertexts,
|
||||
{ machine.setup<FD>().params, evaluation, evaluation });
|
||||
}
|
||||
|
||||
@@ -64,8 +67,8 @@ void PairwiseGenerator<FD>::run()
|
||||
c.mul(a, b);
|
||||
timers["Plaintext multiplication"].stop();
|
||||
timers["FFT of b"].start();
|
||||
for (int i = 0; i < machine.sec; i++)
|
||||
b_mod_q.at(i).from_vec(b.at(i).get_poly());
|
||||
for (int i = 0; i < n_ciphertexts; i++)
|
||||
b_mod_q.at(i).from(b.at(i).get_iterator());
|
||||
timers["FFT of b"].stop();
|
||||
timers["Proof exchange"].start();
|
||||
size_t verifier_memory = EC.create_more(ciphertexts, cleartexts);
|
||||
@@ -73,7 +76,7 @@ void PairwiseGenerator<FD>::run()
|
||||
volatile_memory = max(prover_memory, verifier_memory);
|
||||
|
||||
Rq_Element values({machine.setup<FD>().params, evaluation, evaluation});
|
||||
for (int k = 0; k < machine.sec; k++)
|
||||
for (int k = 0; k < n_ciphertexts; k++)
|
||||
{
|
||||
producer.ai = a[k];
|
||||
producer.bi = b[k];
|
||||
@@ -90,7 +93,7 @@ void PairwiseGenerator<FD>::run()
|
||||
else
|
||||
{
|
||||
timers["Plaintext conversion"].start();
|
||||
values.from_vec(producer.values[j].get_poly());
|
||||
values.from(producer.values[j].get_iterator());
|
||||
timers["Plaintext conversion"].stop();
|
||||
}
|
||||
|
||||
@@ -122,7 +125,7 @@ void PairwiseGenerator<FD>::generate_inputs(int player)
|
||||
{
|
||||
SeededPRNG G;
|
||||
b[0].randomize(G);
|
||||
b_mod_q.at(0).from_vec(b.at(0).get_poly());
|
||||
b_mod_q.at(0).from(b.at(0).get_iterator());
|
||||
producer.macs[0].mul(machine.setup<FD>().alpha, b[0]);
|
||||
}
|
||||
else
|
||||
|
||||
@@ -29,6 +29,8 @@ class PairwiseGenerator : public GeneratorBase
|
||||
MultiEncCommit<FD> EC;
|
||||
MAC_Check<T> MC;
|
||||
|
||||
int n_ciphertexts;
|
||||
|
||||
// temporary data
|
||||
AddableVector<Ciphertext> C;
|
||||
octetStream ciphertexts, cleartexts;
|
||||
|
||||
@@ -48,12 +48,21 @@ void PairwiseSetup<FD>::init(const Player& P, int sec, int plaintext_length,
|
||||
|
||||
template <class FD>
|
||||
void PairwiseSetup<FD>::secure_init(Player& P, PairwiseMachine& machine, int plaintext_length, int sec)
|
||||
{
|
||||
::secure_init(*this, P, machine, plaintext_length, sec);
|
||||
alpha = FieldD;
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void secure_init(T& setup, Player& P, MachineBase& machine,
|
||||
int plaintext_length, int sec)
|
||||
{
|
||||
machine.sec = sec;
|
||||
sec = max(sec, 40);
|
||||
machine.drown_sec = sec;
|
||||
string filename = PREP_DIR "Params-" + FD::T::type_string() + "-"
|
||||
+ to_string(plaintext_length) + "-" + to_string(sec) + "-P"
|
||||
string filename = PREP_DIR + T::name() + "-"
|
||||
+ to_string(plaintext_length) + "-" + to_string(sec) + "-"
|
||||
+ to_string(CowGearOptions::singleton.top_gear()) + "-P"
|
||||
+ to_string(P.my_num());
|
||||
try
|
||||
{
|
||||
@@ -61,38 +70,54 @@ void PairwiseSetup<FD>::secure_init(Player& P, PairwiseMachine& machine, int pla
|
||||
octetStream os;
|
||||
os.input(file);
|
||||
os.get(machine.extra_slack);
|
||||
params.unpack(os);
|
||||
FieldD.unpack(os);
|
||||
FieldD.init_field();
|
||||
check(P, machine);
|
||||
setup.unpack(os);
|
||||
setup.check(P, machine);
|
||||
}
|
||||
catch (...)
|
||||
{
|
||||
cout << "Finding parameters for security " << sec << " and field size ~2^"
|
||||
<< plaintext_length << endl;
|
||||
machine.extra_slack = generate_semi_setup(plaintext_length, sec, params, FieldD, true);
|
||||
check(P, machine);
|
||||
setup.generate(P, machine, plaintext_length, sec);
|
||||
setup.check(P, machine);
|
||||
octetStream os;
|
||||
os.store(machine.extra_slack);
|
||||
params.pack(os);
|
||||
FieldD.pack(os);
|
||||
setup.pack(os);
|
||||
ofstream file(filename);
|
||||
os.output(file);
|
||||
}
|
||||
alpha = FieldD;
|
||||
}
|
||||
|
||||
template <class FD>
|
||||
void PairwiseSetup<FD>::check(Player& P, PairwiseMachine& machine)
|
||||
void PairwiseSetup<FD>::generate(Player&, MachineBase& machine,
|
||||
int plaintext_length, int sec)
|
||||
{
|
||||
machine.extra_slack = generate_semi_setup(plaintext_length, sec, params,
|
||||
FieldD, true);
|
||||
}
|
||||
|
||||
template<class FD>
|
||||
void PairwiseSetup<FD>::pack(octetStream& os) const
|
||||
{
|
||||
params.pack(os);
|
||||
FieldD.pack(os);
|
||||
}
|
||||
|
||||
template<class FD>
|
||||
void PairwiseSetup<FD>::unpack(octetStream& os)
|
||||
{
|
||||
params.unpack(os);
|
||||
FieldD.unpack(os);
|
||||
FieldD.init_field();
|
||||
}
|
||||
|
||||
template <class FD>
|
||||
void PairwiseSetup<FD>::check(Player& P, MachineBase& machine)
|
||||
{
|
||||
Bundle<octetStream> bundle(P);
|
||||
bundle.mine.store(machine.extra_slack);
|
||||
params.pack(bundle.mine);
|
||||
FieldD.hash(bundle.mine);
|
||||
P.Broadcast_Receive(bundle, true);
|
||||
for (auto& os : bundle)
|
||||
if (os != bundle.mine)
|
||||
throw runtime_error("mismatch of parameters among parties");
|
||||
bundle.compare(P);
|
||||
}
|
||||
|
||||
template <class FD>
|
||||
@@ -161,3 +186,6 @@ void PairwiseSetup<FD>::set_alphai(T alphai)
|
||||
|
||||
template class PairwiseSetup<FFT_Data>;
|
||||
template class PairwiseSetup<P2Data>;
|
||||
|
||||
template void secure_init(PartSetup<FFT_Data>&, Player&, MachineBase&, int, int);
|
||||
template void secure_init(PartSetup<P2Data>&, Player&, MachineBase&, int, int);
|
||||
|
||||
@@ -11,6 +11,11 @@
|
||||
#include "Networking/Player.h"
|
||||
|
||||
class PairwiseMachine;
|
||||
class MachineBase;
|
||||
|
||||
template <class T>
|
||||
void secure_init(T& setup, Player& P, MachineBase& machine,
|
||||
int plaintext_length, int sec);
|
||||
|
||||
template <class FD>
|
||||
class PairwiseSetup
|
||||
@@ -24,15 +29,24 @@ public:
|
||||
Plaintext_<FD> alpha;
|
||||
string dirname;
|
||||
|
||||
static string name()
|
||||
{
|
||||
return "PairwiseParams-" + FD::T::type_string();
|
||||
}
|
||||
|
||||
PairwiseSetup() : params(0), alpha(FieldD) {}
|
||||
|
||||
void init(const Player& P, int sec, int plaintext_length, int& extra_slack);
|
||||
|
||||
void secure_init(Player& P, PairwiseMachine& machine, int plaintext_length, int sec);
|
||||
void check(Player& P, PairwiseMachine& machine);
|
||||
void generate(Player& P, MachineBase& machine, int plaintext_length, int sec);
|
||||
void check(Player& P, MachineBase& machine);
|
||||
void covert_key_generation(Player& P, PairwiseMachine& machine, int num_runs);
|
||||
void covert_mac_generation(Player& P, PairwiseMachine& machine, int num_runs);
|
||||
|
||||
void pack(octetStream& os) const;
|
||||
void unpack(octetStream& os);
|
||||
|
||||
void set_alphai(T alphai);
|
||||
};
|
||||
|
||||
|
||||
@@ -9,6 +9,8 @@
|
||||
#include "Sacrificing.h"
|
||||
#include "Reshare.h"
|
||||
#include "DistDecrypt.h"
|
||||
#include "SimpleEncCommit.h"
|
||||
#include "SimpleMachine.h"
|
||||
#include "Tools/mkpath.h"
|
||||
|
||||
template<class FD>
|
||||
@@ -610,7 +612,7 @@ InputProducer<FD>::~InputProducer()
|
||||
template<class FD>
|
||||
void InputProducer<FD>::run(const Player& P, const FHE_PK& pk,
|
||||
const Ciphertext& calpha, EncCommitBase_<FD>& EC, DistDecrypt<FD>& dd,
|
||||
const T& alphai)
|
||||
const T& alphai, int player)
|
||||
{
|
||||
(void)alphai;
|
||||
|
||||
@@ -620,43 +622,86 @@ void InputProducer<FD>::run(const Player& P, const FHE_PK& pk,
|
||||
G.ReSeed();
|
||||
|
||||
Ciphertext gama(params),dummyc(params),ca(params);
|
||||
vector<octetStream> oca(P.num_players());
|
||||
const FD& FieldD = dd.f.get_field();
|
||||
Plaintext<T,FD,S> a(FieldD),ai(FieldD),gai(FieldD);
|
||||
Random_Coins rc(params);
|
||||
|
||||
this->n_slots = FieldD.num_slots();
|
||||
|
||||
Share<T> Sh;
|
||||
|
||||
a.randomize(G);
|
||||
rc.generate(G);
|
||||
pk.encrypt(ca, a, rc);
|
||||
ca.pack(oca[P.my_num()]);
|
||||
P.Broadcast_Receive(oca);
|
||||
inputs.resize(P.num_players());
|
||||
|
||||
for (int j = 0; j < P.num_players(); j++)
|
||||
int min, max;
|
||||
if (player < 0)
|
||||
{
|
||||
ca.unpack(oca[j]);
|
||||
// Reshare the aj values
|
||||
dd.reshare(ai, ca, EC);
|
||||
min = 0;
|
||||
max = P.num_players();
|
||||
}
|
||||
else
|
||||
{
|
||||
min = player;
|
||||
max = player + 1;
|
||||
}
|
||||
|
||||
// Generate encrypted MAC values
|
||||
mul(gama, calpha, ca, pk);
|
||||
map<string, Timer> timers;
|
||||
assert(EC.machine);
|
||||
SimpleEncCommit_<FD> personal_EC(P, pk, FieldD, timers, *EC.machine, 0);
|
||||
octetStream ciphertexts, cleartexts;
|
||||
|
||||
// Get shares on the MACs
|
||||
dd.reshare(gai, gama, EC);
|
||||
|
||||
for (unsigned int i = 0; i < ai.num_slots(); i++)
|
||||
for (int j = min; j < max; j++)
|
||||
{
|
||||
AddableVector<Ciphertext> C;
|
||||
vector<Plaintext_<FD>> m(EC.machine->sec, FieldD);
|
||||
if (j == P.my_num())
|
||||
{
|
||||
Sh.set_share(ai.element(i));
|
||||
Sh.set_mac(gai.element(i));
|
||||
if (write_output)
|
||||
for (auto& x : m)
|
||||
x.randomize(G);
|
||||
personal_EC.generate_proof(C, m, ciphertexts, cleartexts);
|
||||
P.send_all(ciphertexts, true);
|
||||
P.send_all(cleartexts, true);
|
||||
}
|
||||
else
|
||||
{
|
||||
P.receive_player(j, ciphertexts, true);
|
||||
P.receive_player(j, cleartexts, true);
|
||||
C.resize(personal_EC.machine->sec, pk.get_params());
|
||||
Verifier<FD, fixint<GFP_MOD_SZ>>(personal_EC.proof).NIZKPoK(C, ciphertexts,
|
||||
cleartexts, pk, false, false);
|
||||
}
|
||||
|
||||
inputs[j].clear();
|
||||
|
||||
for (size_t i = 0; i < C.size(); i++)
|
||||
{
|
||||
auto& ca = C.at(i);
|
||||
auto& a = m.at(i);
|
||||
|
||||
// Reshare the aj values
|
||||
dd.reshare(ai, ca, EC);
|
||||
|
||||
// Generate encrypted MAC values
|
||||
mul(gama, calpha, ca, pk);
|
||||
|
||||
// Get shares on the MACs
|
||||
dd.reshare(gai, gama, EC);
|
||||
|
||||
for (unsigned int i = 0; i < ai.num_slots(); i++)
|
||||
{
|
||||
Sh.output(outf[j], false);
|
||||
if (j == P.my_num())
|
||||
Sh.set_share(ai.element(i));
|
||||
Sh.set_mac(gai.element(i));
|
||||
if (write_output)
|
||||
{
|
||||
a.element(i).output(outf[j], false);
|
||||
Sh.output(outf[j], false);
|
||||
if (j == P.my_num())
|
||||
{
|
||||
a.element(i).output(outf[j], false);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
inputs[j].push_back({Sh, {}});
|
||||
if (j == P.my_num())
|
||||
inputs[j].back().value = a.element(i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -221,6 +221,8 @@ class InputProducer : public Producer<FD>
|
||||
bool write_output;
|
||||
|
||||
public:
|
||||
vector<vector<InputTuple<Share<T>>>> inputs;
|
||||
|
||||
InputProducer(const Player& P, int output_thread = 0, bool write_output = true,
|
||||
string dir = PREP_DIR);
|
||||
~InputProducer();
|
||||
@@ -228,7 +230,15 @@ public:
|
||||
string data_type() { return "Inputs"; }
|
||||
|
||||
void run(const Player& P, const FHE_PK& pk, const Ciphertext& calpha,
|
||||
EncCommitBase_<FD>& EC, DistDecrypt<FD>& dd, const T& alphai);
|
||||
EncCommitBase_<FD>& EC, DistDecrypt<FD>& dd, const T& alphai)
|
||||
{
|
||||
run(P, pk, calpha, EC, dd, alphai, -1);
|
||||
}
|
||||
|
||||
void run(const Player& P, const FHE_PK& pk, const Ciphertext& calpha,
|
||||
EncCommitBase_<FD>& EC, DistDecrypt<FD>& dd, const T& alphai,
|
||||
int player);
|
||||
|
||||
int sacrifice(const Player& P, MAC_Check<T>& MC);
|
||||
|
||||
// no ops
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
#include "Proof.h"
|
||||
#include "FHE/P2Data.h"
|
||||
#include "FHEOffline/EncCommit.h"
|
||||
#include "Math/Z2k.hpp"
|
||||
|
||||
double Proof::dist = 0;
|
||||
|
||||
@@ -32,22 +33,50 @@ bigint Proof::slack(int slack, int sec, int phim)
|
||||
}
|
||||
}
|
||||
|
||||
void Proof::get_challenge(vector<int>& e, const octetStream& ciphertexts) const
|
||||
void Proof::set_challenge(const octetStream& ciphertexts)
|
||||
{
|
||||
unsigned int i;
|
||||
bigint hashout = ciphertexts.check_sum();
|
||||
|
||||
for (i=0; i<sec; i++)
|
||||
{ e[i]=(hashout.get_ui()>>(i))&1; }
|
||||
octetStream hash = ciphertexts.hash();
|
||||
PRNG G;
|
||||
assert(hash.get_length() >= SEED_SIZE);
|
||||
G.SetSeed(hash.get_data());
|
||||
set_challenge(G);
|
||||
}
|
||||
|
||||
void Proof::set_challenge(PRNG& G)
|
||||
{
|
||||
unsigned int i;
|
||||
|
||||
if (top_gear)
|
||||
{
|
||||
W.resize(V, vector<int>(U));
|
||||
for (i = 0; i < V; i++)
|
||||
for (unsigned j = 0; j < U; j++)
|
||||
W[i][j] = G.get_uint(2 * phim) - 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
e.resize(sec);
|
||||
for (i = 0; i < sec; i++)
|
||||
{
|
||||
e[i] = G.get_bit();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void Proof::generate_challenge(const Player& P)
|
||||
{
|
||||
GlobalPRNG G(P);
|
||||
set_challenge(G);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
class AbsoluteBoundChecker
|
||||
{
|
||||
bigint bound, neg_bound;
|
||||
T bound, neg_bound;
|
||||
|
||||
public:
|
||||
AbsoluteBoundChecker(bigint bound) : bound(bound), neg_bound(-bound) {}
|
||||
bool outside(const bigint& value, double& dist)
|
||||
AbsoluteBoundChecker(T bound) : bound(bound), neg_bound(-this->bound) {}
|
||||
bool outside(const T& value, double& dist)
|
||||
{
|
||||
(void)dist;
|
||||
#ifdef PRINT_MIN_DIST
|
||||
@@ -57,17 +86,17 @@ public:
|
||||
}
|
||||
};
|
||||
|
||||
template <class T>
|
||||
bool Proof::check_bounds(T& z, AddableMatrix<bigint>& t, int i) const
|
||||
template <class T, class X>
|
||||
bool Proof::check_bounds(T& z, X& t, int i) const
|
||||
{
|
||||
unsigned int j,k;
|
||||
|
||||
// Check Bound 1 and Bound 2
|
||||
AbsoluteBoundChecker plain_checker(plain_check * n_proofs);
|
||||
AbsoluteBoundChecker rand_checker(rand_check * n_proofs);
|
||||
AbsoluteBoundChecker<fixint<2>> plain_checker(plain_check * n_proofs);
|
||||
AbsoluteBoundChecker<fixint<2>> rand_checker(rand_check * n_proofs);
|
||||
for (j=0; j<phim; j++)
|
||||
{
|
||||
const bigint& te = z[j];
|
||||
auto& te = z[j];
|
||||
if (plain_checker.outside(te, dist))
|
||||
{
|
||||
cout << "Fail on Check 1 " << i << " " << j << endl;
|
||||
@@ -78,10 +107,10 @@ bool Proof::check_bounds(T& z, AddableMatrix<bigint>& t, int i) const
|
||||
}
|
||||
for (k=0; k<3; k++)
|
||||
{
|
||||
const vector<bigint>& coeffs = t[k];
|
||||
auto& coeffs = t[k];
|
||||
for (j=0; j<coeffs.size(); j++)
|
||||
{
|
||||
const bigint& te = coeffs.at(j);
|
||||
auto& te = coeffs.at(j);
|
||||
if (rand_checker.outside(te, dist))
|
||||
{
|
||||
cout << "Fail on Check 2 " << k << " : " << i << " " << j << endl;
|
||||
@@ -101,7 +130,7 @@ Proof::Preimages::Preimages(int size, const FHE_PK& pk, const bigint& p, int n_p
|
||||
// extra limb for addition
|
||||
bigint limit = p << (64 + n_players);
|
||||
m.allocate_slots(limit);
|
||||
r.allocate_slots(limit);
|
||||
r.allocate_slots(n_players);
|
||||
m_tmp = m[0][0];
|
||||
r_tmp = r[0][0];
|
||||
}
|
||||
@@ -136,8 +165,8 @@ void Proof::Preimages::unpack(octetStream& os)
|
||||
unsigned int size;
|
||||
os.get(size);
|
||||
m.resize(size);
|
||||
if (size != r.size())
|
||||
throw runtime_error("unexpected size of preimage randomness");
|
||||
assert(not r.empty());
|
||||
r.resize(size, r[0]);
|
||||
for (size_t i = 0; i < m.size(); i++)
|
||||
{
|
||||
m[i].unpack(os);
|
||||
@@ -151,7 +180,6 @@ void Proof::Preimages::check_sizes()
|
||||
throw runtime_error("preimage sizes don't match");
|
||||
}
|
||||
|
||||
template bool Proof::check_bounds(Plaintext<gfp,FFT_Data,bigint>& z, AddableMatrix<bigint>& t, int i) const;
|
||||
template bool Proof::check_bounds(AddableVector<bigint>& z, AddableMatrix<bigint>& t, int i) const;
|
||||
|
||||
template bool Proof::check_bounds(Plaintext_<P2Data>& z, AddableMatrix<bigint>& t, int i) const;
|
||||
template bool Proof::check_bounds(AddableVector<fixint<2>>& z, AddableMatrix<fixint<0>>& t, int i) const;
|
||||
template bool Proof::check_bounds(AddableVector<fixint<2>>& z, AddableMatrix<fixint<1>>& t, int i) const;
|
||||
template bool Proof::check_bounds(AddableVector<fixint<2>>& z, AddableMatrix<fixint<2>>& t, int i) const;
|
||||
|
||||
@@ -8,6 +8,7 @@ using namespace std;
|
||||
#include "Math/bigint.h"
|
||||
#include "FHE/Ciphertext.h"
|
||||
#include "FHE/AddableVector.h"
|
||||
#include "Protocols/CowGearOptions.h"
|
||||
|
||||
#include "config.h"
|
||||
|
||||
@@ -21,6 +22,8 @@ enum SlackType
|
||||
|
||||
class Proof
|
||||
{
|
||||
unsigned int sec;
|
||||
|
||||
Proof(); // Private to avoid default
|
||||
|
||||
public:
|
||||
@@ -29,12 +32,14 @@ class Proof
|
||||
|
||||
class Preimages
|
||||
{
|
||||
bigint m_tmp;
|
||||
AddableVector<bigint> r_tmp;
|
||||
typedef Int_Random_Coins::value_type::value_type r_type;
|
||||
|
||||
fixint<GFP_MOD_SZ> m_tmp;
|
||||
AddableVector<r_type> r_tmp;
|
||||
|
||||
public:
|
||||
Preimages(int size, const FHE_PK& pk, const bigint& p, int n_players);
|
||||
AddableMatrix<bigint> m;
|
||||
AddableMatrix<fixint<GFP_MOD_SZ>> m;
|
||||
Randomness r;
|
||||
void add(octetStream& os);
|
||||
void pack(octetStream& os);
|
||||
@@ -43,18 +48,22 @@ class Proof
|
||||
size_t report_size(ReportType type) { return m.report_size(type) + r.report_size(type); }
|
||||
};
|
||||
|
||||
unsigned int sec;
|
||||
bigint tau,rho;
|
||||
|
||||
unsigned int phim;
|
||||
int B_plain_length, B_rand_length;
|
||||
bigint plain_check, rand_check;
|
||||
unsigned int V;
|
||||
unsigned int U;
|
||||
|
||||
const FHE_PK* pk;
|
||||
|
||||
int n_proofs;
|
||||
|
||||
vector<int> e;
|
||||
vector<vector<int>> W;
|
||||
bool top_gear;
|
||||
|
||||
static double dist;
|
||||
|
||||
protected:
|
||||
@@ -65,19 +74,67 @@ class Proof
|
||||
tau=Tau; rho=Rho;
|
||||
|
||||
phim=(pk.get_params()).phi_m();
|
||||
V=2*sec-1;
|
||||
|
||||
top_gear = use_top_gear(pk);
|
||||
if (top_gear)
|
||||
{
|
||||
V = ceil((sec + 2) / log2(2 * phim + 1));
|
||||
U = 2 * V;
|
||||
#ifdef VERBOSE
|
||||
cerr << "Using " << U << " ciphertexts per proof" << endl;
|
||||
#endif
|
||||
}
|
||||
else
|
||||
{
|
||||
U = sec;
|
||||
V = 2 * sec - 1;
|
||||
}
|
||||
}
|
||||
|
||||
Proof(int sec, const FHE_PK& pk, int n_proofs = 1) :
|
||||
Proof(sec, pk.p() / 2, 2 * 3.2 * sqrt(pk.get_params().phi_m()), pk,
|
||||
Proof(sec, pk.p() / 2,
|
||||
pk.get_params().get_DG().get_NewHopeB(), pk,
|
||||
n_proofs) {}
|
||||
|
||||
virtual ~Proof() {}
|
||||
|
||||
public:
|
||||
static bigint slack(int slack, int sec, int phim);
|
||||
|
||||
void get_challenge(vector<int>& e, const octetStream& ciphertexts) const;
|
||||
template <class T>
|
||||
bool check_bounds(T& z, AddableMatrix<bigint>& t, int i) const;
|
||||
static bool use_top_gear(const FHE_PK& pk)
|
||||
{
|
||||
return CowGearOptions::singleton.top_gear() and pk.p() > 2;
|
||||
}
|
||||
|
||||
static int n_ciphertext_per_proof(int sec, const FHE_PK& pk)
|
||||
{
|
||||
return Proof(sec, pk, 1).U;
|
||||
}
|
||||
|
||||
void set_challenge(const octetStream& ciphertexts);
|
||||
void set_challenge(PRNG& G);
|
||||
void generate_challenge(const Player& P);
|
||||
|
||||
template <class T, class X>
|
||||
bool check_bounds(T& z, X& t, int i) const;
|
||||
|
||||
template<class T, class U>
|
||||
void apply_challenge(int i, T& output, const U& input, const FHE_PK& pk) const
|
||||
{
|
||||
if (top_gear)
|
||||
{
|
||||
for (unsigned j = 0; j < this->U; j++)
|
||||
if (W[i][j] >= 0)
|
||||
output += input[j].mul_by_X_i(W[i][j], pk);
|
||||
}
|
||||
else
|
||||
for (unsigned k = 0; k < sec; k++)
|
||||
{
|
||||
unsigned j = (i + 1) - (k + 1);
|
||||
if (j < sec && e.at(j))
|
||||
output += input.at(j);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
class NonInteractiveProof : public Proof
|
||||
@@ -111,8 +168,7 @@ public:
|
||||
Proof(sec, pk, n_proofs)
|
||||
{
|
||||
bigint B;
|
||||
// using mu = 1
|
||||
B = bigint(1) << (sec - 1);
|
||||
B = bigint(1) << sec;
|
||||
B_plain_length = numBits(B * tau);
|
||||
B_rand_length = numBits(B * rho);
|
||||
// leeway for completeness
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
#include "FHE/P2Data.h"
|
||||
#include "Tools/random.h"
|
||||
#include "Math/Z2k.hpp"
|
||||
|
||||
|
||||
template <class FD, class U>
|
||||
@@ -62,35 +63,25 @@ template <class FD, class U>
|
||||
bool Prover<FD,U>::Stage_2(Proof& P, octetStream& cleartexts,
|
||||
const vector<U>& x,
|
||||
const Proof::Randomness& r,
|
||||
const vector<int>& e)
|
||||
const FHE_PK& pk)
|
||||
{
|
||||
size_t allocate = P.V * P.phim
|
||||
* (5 + numBytes(P.plain_check) + 3 * (5 + numBytes(P.rand_check)));
|
||||
cleartexts.resize_precise(allocate);
|
||||
cleartexts.reset_write_head();
|
||||
|
||||
unsigned int i,k;
|
||||
int j,ee;
|
||||
unsigned int i;
|
||||
#ifndef LESS_ALLOC_MORE_MEM
|
||||
AddableVector<bigint> z;
|
||||
AddableMatrix<bigint> t;
|
||||
AddableVector<fixint<GFP_MOD_SZ>> z;
|
||||
AddableMatrix<fixint<GFP_MOD_SZ>> t;
|
||||
#endif
|
||||
cleartexts.reset_write_head();
|
||||
cleartexts.store(P.V);
|
||||
for (i=0; i<P.V; i++)
|
||||
{ z=y[i];
|
||||
t=s[i];
|
||||
for (k=0; k<P.sec; k++)
|
||||
{ j=(i+1)-(k+1);
|
||||
if (j<0 || j>=(int) P.sec) { ee=0; }
|
||||
else { ee=e[j]; }
|
||||
|
||||
if (ee!=0)
|
||||
{
|
||||
z += x[j];
|
||||
t += r[j];
|
||||
}
|
||||
}
|
||||
P.apply_challenge(i, z, x, pk);
|
||||
P.apply_challenge(i, t, r, pk);
|
||||
if (not P.check_bounds(z, t, i))
|
||||
return false;
|
||||
z.pack(cleartexts);
|
||||
@@ -118,8 +109,6 @@ size_t Prover<FD,U>::NIZKPoK(Proof& P, octetStream& ciphertexts, octetStream& cl
|
||||
const Proof::Randomness& r,
|
||||
bool Diag,bool binary)
|
||||
{
|
||||
vector<int> e(P.sec);
|
||||
|
||||
// AElement<T> AE;
|
||||
// for (i=0; i<P.sec; i++)
|
||||
// { AE.assign(x.at(i));
|
||||
@@ -134,9 +123,9 @@ size_t Prover<FD,U>::NIZKPoK(Proof& P, octetStream& ciphertexts, octetStream& cl
|
||||
while (!ok)
|
||||
{ cnt++;
|
||||
Stage_1(P,ciphertexts,c,pk,Diag,binary);
|
||||
P.get_challenge(e, ciphertexts);
|
||||
P.set_challenge(ciphertexts);
|
||||
// Check check whether we are OK, or whether we should abort
|
||||
ok = Stage_2(P,cleartexts,x,r,e);
|
||||
ok = Stage_2(P,cleartexts,x,r,pk);
|
||||
}
|
||||
if (cnt > 1)
|
||||
cout << "\t\tNumber iterations of prover = " << cnt << endl;
|
||||
@@ -173,7 +162,7 @@ void Prover<FD, U>::report_size(ReportType type, MemoryUsage& res)
|
||||
|
||||
|
||||
template class Prover<FFT_Data, Plaintext_<FFT_Data> >;
|
||||
template class Prover<FFT_Data, AddableVector<bigint> >;
|
||||
//template class Prover<FFT_Data, AddableVector<bigint> >;
|
||||
|
||||
template class Prover<P2Data, Plaintext_<P2Data> >;
|
||||
template class Prover<P2Data, AddableVector<bigint> >;
|
||||
//template class Prover<P2Data, AddableVector<bigint> >;
|
||||
|
||||
@@ -14,8 +14,8 @@ class Prover
|
||||
AddableVector< Plaintext_<FD> > y;
|
||||
|
||||
#ifdef LESS_ALLOC_MORE_MEM
|
||||
AddableVector<bigint> z;
|
||||
AddableMatrix<bigint> t;
|
||||
AddableVector<fixint<GFP_MOD_SZ>> z;
|
||||
AddableMatrix<Int_Random_Coins::value_type::value_type> t;
|
||||
#endif
|
||||
|
||||
public:
|
||||
@@ -30,7 +30,7 @@ public:
|
||||
bool Stage_2(Proof& P, octetStream& cleartexts,
|
||||
const vector<U>& x,
|
||||
const Proof::Randomness& r,
|
||||
const vector<int>& e);
|
||||
const FHE_PK& pk);
|
||||
|
||||
/* Only has a non-interactive version using the ROM
|
||||
- If Diag is true then the plaintexts x are assumed to be
|
||||
|
||||
@@ -85,6 +85,7 @@ void Triple_Checking(const Player& P, MAC_Check<T>& MC, int nm,
|
||||
|
||||
// Triple checking
|
||||
int left_todo=nm;
|
||||
factory.triples.clear();
|
||||
while (left_todo>0)
|
||||
{ int this_loop=amortize;
|
||||
if (this_loop>left_todo)
|
||||
@@ -121,7 +122,17 @@ void Triple_Checking(const Player& P, MAC_Check<T>& MC, int nm,
|
||||
|
||||
for (int i=0; i<this_loop; i++)
|
||||
{ if (!Tau[i].is_zero())
|
||||
{ throw Offline_Check_Error("Multiplication Triples"); }
|
||||
{
|
||||
MC.POpen(PO, {a1[i], b1[i], c1[i], a2[i], b2[i], c2[i]}, P);
|
||||
for (int j = 0; j < 2; j++)
|
||||
{
|
||||
auto prod = PO[3 * j + 1] * PO[3 * j];
|
||||
if (PO[3 * j + 2] != prod)
|
||||
cout << PO[3 * j + 2] << " != " << prod << " = "
|
||||
<< PO[3 * j + 1] << " * " << PO[3 * j] << endl;
|
||||
}
|
||||
throw Offline_Check_Error("Multiplication Triples");
|
||||
}
|
||||
if (write_output)
|
||||
{
|
||||
a1[i].output(outf,false);
|
||||
@@ -326,6 +337,7 @@ void Square_Checking(const Player& P, MAC_Check<T>& MC, int ns,
|
||||
|
||||
// Do the square checking
|
||||
int left_todo=ns;
|
||||
square_factory.tuples.clear();
|
||||
while (left_todo>0)
|
||||
{ int this_loop=amortize;
|
||||
if (this_loop>left_todo)
|
||||
@@ -363,6 +375,8 @@ void Square_Checking(const Player& P, MAC_Check<T>& MC, int ns,
|
||||
{
|
||||
a[i].output(outf_s,false); b[i].output(outf_s,false);
|
||||
}
|
||||
else
|
||||
square_factory.tuples.push_back({{a[i], b[i]}});
|
||||
}
|
||||
left_todo-=this_loop;
|
||||
}
|
||||
|
||||
@@ -25,6 +25,8 @@ template <class T>
|
||||
class TupleSacriFactory
|
||||
{
|
||||
public:
|
||||
vector<array<T, 2>> tuples;
|
||||
|
||||
virtual ~TupleSacriFactory() {}
|
||||
virtual void get(T& a, T& b) = 0;
|
||||
};
|
||||
|
||||
@@ -14,7 +14,8 @@
|
||||
|
||||
template<class T, class FD, class S>
|
||||
SimpleEncCommitBase<T, FD, S>::SimpleEncCommitBase(const MachineBase& machine) :
|
||||
sec(machine.sec), extra_slack(machine.extra_slack), n_rounds(0)
|
||||
EncCommitBase_<FD>(&machine),
|
||||
extra_slack(machine.extra_slack), n_rounds(0)
|
||||
{
|
||||
}
|
||||
|
||||
@@ -36,7 +37,7 @@ NonInteractiveProofSimpleEncCommit<FD>::NonInteractiveProofSimpleEncCommit(
|
||||
P(P), pk(pk), FTD(FTD),
|
||||
proof(machine.sec, pk, machine.extra_slack),
|
||||
#ifdef LESS_ALLOC_MORE_MEM
|
||||
r(this->sec, this->pk.get_params()), prover(proof, FTD),
|
||||
r(proof.U, this->pk.get_params()), prover(proof, FTD),
|
||||
verifier(proof),
|
||||
#endif
|
||||
timers(timers)
|
||||
@@ -46,9 +47,9 @@ NonInteractiveProofSimpleEncCommit<FD>::NonInteractiveProofSimpleEncCommit(
|
||||
template <class FD>
|
||||
SimpleEncCommitFactory<FD>::SimpleEncCommitFactory(const FHE_PK& pk,
|
||||
const FD& FTD, const MachineBase& machine) :
|
||||
cnt(-1), n_calls(0)
|
||||
cnt(-1), n_calls(0), pk(pk)
|
||||
{
|
||||
int sec = machine.sec;
|
||||
int sec = Proof::n_ciphertext_per_proof(machine.sec, pk);
|
||||
c.resize(sec, pk.get_params());
|
||||
m.resize(sec, FTD);
|
||||
for (int i = 0; i < sec; i++)
|
||||
@@ -70,6 +71,13 @@ void SimpleEncCommitFactory<FD>::next(Plaintext_<FD>& mess, Ciphertext& C)
|
||||
create_more();
|
||||
mess = m[cnt];
|
||||
C = c[cnt];
|
||||
|
||||
if (Proof::use_top_gear(pk))
|
||||
{
|
||||
mess = mess + mess;
|
||||
C = C + C;
|
||||
}
|
||||
|
||||
cnt--;
|
||||
n_calls++;
|
||||
}
|
||||
@@ -84,18 +92,21 @@ void SimpleEncCommitFactory<FD>::prepare_plaintext(PRNG& G)
|
||||
template<class T,class FD,class S>
|
||||
void SimpleEncCommitBase<T, FD, S>::generate_ciphertexts(
|
||||
AddableVector<Ciphertext>& c, const vector<Plaintext_<FD> >& m,
|
||||
Proof::Randomness& r, const FHE_PK& pk, TimerMap& timers)
|
||||
Proof::Randomness& r, const FHE_PK& pk, TimerMap& timers,
|
||||
Proof& proof)
|
||||
{
|
||||
timers["Generating"].start();
|
||||
PRNG G;
|
||||
G.ReSeed();
|
||||
prepare_plaintext(G);
|
||||
Random_Coins rc(pk.get_params());
|
||||
for (int i = 0; i < sec; i++)
|
||||
c.resize(proof.U, pk);
|
||||
r.resize(proof.U, pk);
|
||||
for (unsigned i = 0; i < proof.U; i++)
|
||||
{
|
||||
r[i].sample(G);
|
||||
rc.assign(r[i]);
|
||||
pk.encrypt(c[i], m[i], rc);
|
||||
pk.encrypt(c[i], m.at(i), rc);
|
||||
}
|
||||
timers["Generating"].stop();
|
||||
memory_usage.update("random coins", rc.report_size(CAPACITY));
|
||||
@@ -103,14 +114,14 @@ void SimpleEncCommitBase<T, FD, S>::generate_ciphertexts(
|
||||
|
||||
template <class FD>
|
||||
size_t NonInteractiveProofSimpleEncCommit<FD>::generate_proof(AddableVector<Ciphertext>& c,
|
||||
const vector<Plaintext_<FD> >& m, octetStream& ciphertexts,
|
||||
vector<Plaintext_<FD> >& m, octetStream& ciphertexts,
|
||||
octetStream& cleartexts)
|
||||
{
|
||||
timers["Proving"].start();
|
||||
#ifndef LESS_ALLOC_MORE_MEM
|
||||
Proof::Randomness r(this->sec, pk.get_params());
|
||||
Proof::Randomness r(proof.U, pk.get_params());
|
||||
#endif
|
||||
this->generate_ciphertexts(c, m, r, pk, timers);
|
||||
this->generate_ciphertexts(c, m, r, pk, timers, proof);
|
||||
#ifndef LESS_ALLOC_MORE_MEM
|
||||
Prover<FD, Plaintext_<FD> > prover(proof, FTD);
|
||||
#endif
|
||||
@@ -118,6 +129,13 @@ size_t NonInteractiveProofSimpleEncCommit<FD>::generate_proof(AddableVector<Ciph
|
||||
pk, c, m, r, false, false);
|
||||
timers["Proving"].stop();
|
||||
|
||||
if (proof.top_gear)
|
||||
{
|
||||
c += c;
|
||||
for (auto& mm : m)
|
||||
mm += mm;
|
||||
}
|
||||
|
||||
// cout << "Checking my own proof" << endl;
|
||||
// if (!Verifier<gfp>().NIZKPoK(c[P.my_num()], proofs[P.my_num()], pk, false, false))
|
||||
// throw runtime_error("proof check failed");
|
||||
@@ -137,7 +155,7 @@ void SimpleEncCommit<T,FD,S>::create_more()
|
||||
cleartexts);
|
||||
cout << "Done checking proofs in round " << this->n_rounds << endl;
|
||||
this->n_rounds++;
|
||||
this->cnt = this->sec - 1;
|
||||
this->cnt = this->proof.U - 1;
|
||||
this->memory_usage.update("serialized ciphertexts",
|
||||
ciphertexts.get_max_length());
|
||||
this->memory_usage.update("serialized cleartexts", cleartexts.get_max_length());
|
||||
@@ -150,7 +168,7 @@ size_t NonInteractiveProofSimpleEncCommit<FD>::create_more(octetStream& cipherte
|
||||
octetStream& cleartexts)
|
||||
{
|
||||
AddableVector<Ciphertext> others_ciphertexts;
|
||||
others_ciphertexts.resize(this->sec, pk.get_params());
|
||||
others_ciphertexts.resize(proof.U, pk.get_params());
|
||||
for (int i = 1; i < P.num_players(); i++)
|
||||
{
|
||||
#ifdef VERBOSE_HE
|
||||
@@ -189,18 +207,31 @@ void SimpleEncCommit<T, FD, S>::add_ciphertexts(
|
||||
vector<Ciphertext>& ciphertexts, int offset)
|
||||
{
|
||||
(void)offset;
|
||||
for (int j = 0; j < this->sec; j++)
|
||||
for (unsigned j = 0; j < this->proof.U; j++)
|
||||
add(this->c[j], this->c[j], ciphertexts[j]);
|
||||
}
|
||||
|
||||
template<class FD>
|
||||
SummingEncCommit<FD>::SummingEncCommit(const Player& P, const FHE_PK& pk,
|
||||
const FD& FTD, map<string, Timer>& timers, const MachineBase& machine,
|
||||
int thread_num) :
|
||||
SimpleEncCommitFactory<FD>(pk, FTD, machine), SimpleEncCommitBase_<FD>(
|
||||
machine), proof(machine.sec, pk, P.num_players()), pk(pk), FTD(
|
||||
FTD), P(P), thread_num(thread_num),
|
||||
#ifdef LESS_ALLOC_MORE_MEM
|
||||
prover(proof, FTD), verifier(proof), preimages(proof.V,
|
||||
this->pk, FTD.get_prime(), P.num_players()),
|
||||
#endif
|
||||
timers(timers)
|
||||
{
|
||||
}
|
||||
|
||||
template<class FD>
|
||||
void SummingEncCommit<FD>::create_more()
|
||||
{
|
||||
octetStream cleartexts;
|
||||
const Player& P = this->P;
|
||||
InteractiveProof proof(this->sec, this->pk, P.num_players());
|
||||
AddableVector<Ciphertext> commitments;
|
||||
vector<int> e(this->sec);
|
||||
size_t prover_size;
|
||||
MemoryUsage& memory_usage = this->memory_usage;
|
||||
TreeSum<Ciphertext> tree_sum(2, 2, thread_num % P.num_players());
|
||||
@@ -210,10 +241,10 @@ void SummingEncCommit<FD>::create_more()
|
||||
#ifdef LESS_ALLOC_MORE_MEM
|
||||
Proof::Randomness& r = preimages.r;
|
||||
#else
|
||||
Proof::Randomness r(this->sec, this->pk.get_params());
|
||||
Proof::Randomness r(proof.U, this->pk.get_params());
|
||||
Prover<FD, Plaintext_<FD> > prover(proof, this->FTD);
|
||||
#endif
|
||||
this->generate_ciphertexts(this->c, this->m, r, pk, timers);
|
||||
this->generate_ciphertexts(this->c, this->m, r, pk, timers, proof);
|
||||
this->timers["Stage 1 of proof"].start();
|
||||
prover.Stage_1(proof, ciphertexts, this->c, this->pk, false, false);
|
||||
this->timers["Stage 1 of proof"].stop();
|
||||
@@ -228,10 +259,10 @@ void SummingEncCommit<FD>::create_more()
|
||||
tree_sum.run(commitments, P);
|
||||
this->timers["Exchanging ciphertexts"].stop();
|
||||
|
||||
generate_challenge(e, P);
|
||||
proof.generate_challenge(P);
|
||||
|
||||
this->timers["Stage 2 of proof"].start();
|
||||
prover.Stage_2(proof, cleartexts, this->m, r, e);
|
||||
prover.Stage_2(proof, cleartexts, this->m, r, pk);
|
||||
this->timers["Stage 2 of proof"].stop();
|
||||
|
||||
prover_size = prover.report_size(CAPACITY) + r.report_size(CAPACITY)
|
||||
@@ -273,10 +304,10 @@ void SummingEncCommit<FD>::create_more()
|
||||
#else
|
||||
Verifier<FD,S> verifier(proof);
|
||||
#endif
|
||||
verifier.Stage_2(e, this->c, ciphertexts, cleartexts,
|
||||
verifier.Stage_2(this->c, ciphertexts, cleartexts,
|
||||
this->pk, false, false);
|
||||
this->timers["Verifying"].stop();
|
||||
this->cnt = this->sec - 1;
|
||||
this->cnt = proof.U - 1;
|
||||
|
||||
this->volatile_memory =
|
||||
+ commitments.report_size(CAPACITY) + ciphertexts.get_max_length()
|
||||
@@ -339,7 +370,7 @@ template <class FD>
|
||||
void MultiEncCommit<FD>::add_ciphertexts(vector<Ciphertext>& ciphertexts,
|
||||
int offset)
|
||||
{
|
||||
for (int i = 0; i < this->sec; i++)
|
||||
for (unsigned i = 0; i < this->proof.U; i++)
|
||||
generator.multipliers[offset - 1]->multiply_and_add(generator.c.at(i),
|
||||
ciphertexts.at(i), generator.b_mod_q.at(i));
|
||||
}
|
||||
|
||||
@@ -20,14 +20,13 @@ template<class T,class FD,class S>
|
||||
class SimpleEncCommitBase : public EncCommitBase<T,FD,S>
|
||||
{
|
||||
protected:
|
||||
int sec;
|
||||
int extra_slack;
|
||||
|
||||
int n_rounds;
|
||||
|
||||
void generate_ciphertexts(AddableVector<Ciphertext>& c,
|
||||
const vector<Plaintext_<FD> >& m, Proof::Randomness& r,
|
||||
const FHE_PK& pk, map<string, Timer>& timers);
|
||||
const FHE_PK& pk, map<string, Timer>& timers, Proof& proof);
|
||||
|
||||
virtual void prepare_plaintext(PRNG& G) = 0;
|
||||
|
||||
@@ -46,12 +45,16 @@ template <class FD>
|
||||
class NonInteractiveProofSimpleEncCommit : public SimpleEncCommitBase_<FD>
|
||||
{
|
||||
protected:
|
||||
typedef bigint S;
|
||||
typedef fixint<GFP_MOD_SZ> S;
|
||||
|
||||
const PlayerBase& P;
|
||||
const FHE_PK& pk;
|
||||
const FD& FTD;
|
||||
|
||||
virtual const FHE_PK& get_pk_for_verification(int offset) = 0;
|
||||
virtual void add_ciphertexts(vector<Ciphertext>& ciphertexts, int offset) = 0;
|
||||
|
||||
public:
|
||||
NonInteractiveProof proof;
|
||||
|
||||
#ifdef LESS_ALLOC_MORE_MEM
|
||||
@@ -60,17 +63,13 @@ protected:
|
||||
Verifier<FD,S> verifier;
|
||||
#endif
|
||||
|
||||
virtual const FHE_PK& get_pk_for_verification(int offset) = 0;
|
||||
virtual void add_ciphertexts(vector<Ciphertext>& ciphertexts, int offset) = 0;
|
||||
|
||||
public:
|
||||
map<string, Timer>& timers;
|
||||
|
||||
NonInteractiveProofSimpleEncCommit(const PlayerBase& P, const FHE_PK& pk,
|
||||
const FD& FTD, map<string, Timer>& timers,
|
||||
const MachineBase& machine);
|
||||
virtual ~NonInteractiveProofSimpleEncCommit() {}
|
||||
size_t generate_proof(AddableVector<Ciphertext>& c, const vector<Plaintext_<FD> >& m,
|
||||
size_t generate_proof(AddableVector<Ciphertext>& c, vector<Plaintext_<FD> >& m,
|
||||
octetStream& ciphertexts, octetStream& cleartexts);
|
||||
size_t create_more(octetStream& my_ciphertext, octetStream& my_cleartext);
|
||||
virtual size_t report_size(ReportType type);
|
||||
@@ -87,6 +86,8 @@ protected:
|
||||
|
||||
int n_calls;
|
||||
|
||||
const FHE_PK& pk;
|
||||
|
||||
void prepare_plaintext(PRNG& G);
|
||||
virtual void create_more() = 0;
|
||||
|
||||
@@ -104,7 +105,8 @@ class SimpleEncCommit: public NonInteractiveProofSimpleEncCommit<FD>,
|
||||
public SimpleEncCommitFactory<FD>
|
||||
{
|
||||
protected:
|
||||
const FHE_PK& get_pk_for_verification(int offset) { (void)offset; return this->pk; }
|
||||
const FHE_PK& get_pk_for_verification(int)
|
||||
{ return NonInteractiveProofSimpleEncCommit<FD>::pk; }
|
||||
void prepare_plaintext(PRNG& G)
|
||||
{ SimpleEncCommitFactory<FD>::prepare_plaintext(G); }
|
||||
void add_ciphertexts(vector<Ciphertext>& ciphertexts, int offset);
|
||||
@@ -127,7 +129,7 @@ template <class FD>
|
||||
class SummingEncCommit: public SimpleEncCommitFactory<FD>,
|
||||
public SimpleEncCommitBase_<FD>
|
||||
{
|
||||
typedef bigint S;
|
||||
typedef fixint<GFP_MOD_SZ> S;
|
||||
|
||||
InteractiveProof proof;
|
||||
const FHE_PK& pk;
|
||||
@@ -148,15 +150,8 @@ public:
|
||||
map<string, Timer>& timers;
|
||||
|
||||
SummingEncCommit(const Player& P, const FHE_PK& pk, const FD& FTD,
|
||||
map<string, Timer>& timers, const MachineBase& machine, int thread_num) :
|
||||
SimpleEncCommitFactory<FD>(pk, FTD, machine), SimpleEncCommitBase_<FD>(machine),
|
||||
proof(this->sec, pk, P.num_players()), pk(pk), FTD(FTD), P(P),
|
||||
thread_num(thread_num),
|
||||
#ifdef LESS_ALLOC_MORE_MEM
|
||||
prover(proof, FTD), verifier(proof), preimages(proof.V, this->pk,
|
||||
FTD.get_prime(), P.num_players()),
|
||||
#endif
|
||||
timers(timers) {}
|
||||
map<string, Timer>& timers, const MachineBase& machine, int thread_num);
|
||||
|
||||
void next(Plaintext_<FD>& mess, Ciphertext& C) { SimpleEncCommitFactory<FD>::next(mess, C); }
|
||||
void create_more();
|
||||
size_t report_size(ReportType type);
|
||||
|
||||
@@ -12,10 +12,11 @@
|
||||
|
||||
template <template <class> class T, class FD>
|
||||
SimpleGenerator<T,FD>::SimpleGenerator(const Names& N, const PartSetup<FD>& setup,
|
||||
const MultiplicativeMachine& machine, int thread_num, Dtype data_type) :
|
||||
GeneratorBase(thread_num, N),
|
||||
setup(setup), machine(machine), dd(P, setup),
|
||||
volatile_memory(0),
|
||||
const MultiplicativeMachine& machine,
|
||||
int thread_num, Dtype data_type, Player* player) :
|
||||
GeneratorBase(thread_num, N, player),
|
||||
setup(setup), machine(machine),
|
||||
volatile_memory(0), dd(P, setup),
|
||||
EC(P, setup.pk, setup.FieldD, timers, machine, thread_num)
|
||||
{
|
||||
if (machine.produce_inputs)
|
||||
@@ -51,14 +52,14 @@ SimpleGenerator<T,FD>::~SimpleGenerator()
|
||||
}
|
||||
|
||||
template <template <class> class T, class FD>
|
||||
void SimpleGenerator<T,FD>::run()
|
||||
void SimpleGenerator<T,FD>::run(bool exhaust)
|
||||
{
|
||||
Timer timer(CLOCK_THREAD_CPUTIME_ID);
|
||||
timer.start();
|
||||
timers["MC init"].start();
|
||||
MAC_Check<typename FD::T> MC(setup.alphai);
|
||||
timers["MC init"].stop();
|
||||
while (total < machine.nTriplesPerThread or EC.has_left())
|
||||
while (total < machine.nTriplesPerThread or (exhaust and EC.has_left()))
|
||||
{
|
||||
producer->run(P, setup.pk, setup.calpha, EC, dd, setup.alphai);
|
||||
producer->sacrifice(P, MC);
|
||||
|
||||
@@ -55,20 +55,20 @@ class SimpleGenerator : public GeneratorBase
|
||||
const PartSetup<FD>& setup;
|
||||
const MultiplicativeMachine& machine;
|
||||
|
||||
SimpleDistDecrypt<FD> dd;
|
||||
|
||||
size_t volatile_memory;
|
||||
|
||||
public:
|
||||
SimpleDistDecrypt<FD> dd;
|
||||
T<FD> EC;
|
||||
Producer<FD>* producer;
|
||||
|
||||
SimpleGenerator(const Names& N, const PartSetup<FD>& setup,
|
||||
const MultiplicativeMachine& machine, int thread_num,
|
||||
Dtype data_type = DATA_TRIPLE);
|
||||
Dtype data_type = DATA_TRIPLE, Player* player = 0);
|
||||
~SimpleGenerator();
|
||||
|
||||
void run();
|
||||
void run() { run(true); }
|
||||
void run(bool exhaust);
|
||||
size_t report_size(ReportType type);
|
||||
void report_size(ReportType type, MemoryUsage& res);
|
||||
size_t report_sent() { return P.sent; }
|
||||
|
||||
@@ -53,8 +53,6 @@ public:
|
||||
class MultiplicativeMachine : public MachineBase
|
||||
{
|
||||
protected:
|
||||
DataSetup setup;
|
||||
|
||||
void parse_options(int argc, const char** argv);
|
||||
|
||||
void generate_setup(int slack);
|
||||
@@ -62,6 +60,8 @@ protected:
|
||||
void fake_keys(int slack);
|
||||
|
||||
public:
|
||||
DataSetup setup;
|
||||
|
||||
virtual ~MultiplicativeMachine() {}
|
||||
|
||||
virtual int get_covert() const { return 0; }
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
#include "Verifier.h"
|
||||
#include "Math/Z2k.hpp"
|
||||
|
||||
template <class FD, class S>
|
||||
Verifier<FD,S>::Verifier(const Proof& proof) : P(proof)
|
||||
Verifier<FD,S>::Verifier(Proof& proof) : P(proof)
|
||||
{
|
||||
#ifdef LESS_ALLOC_MORE_MEM
|
||||
z.resize(proof.phim);
|
||||
@@ -39,19 +40,18 @@ bool Check_Decoding(const vector<S>& AE, bool Diag)
|
||||
|
||||
|
||||
template <class FD, class S>
|
||||
void Verifier<FD,S>::Stage_2(const vector<int>& e,
|
||||
void Verifier<FD,S>::Stage_2(
|
||||
AddableVector<Ciphertext>& c,octetStream& ciphertexts,
|
||||
octetStream& cleartexts,
|
||||
const FHE_PK& pk,bool Diag,bool binary)
|
||||
{
|
||||
unsigned int i,k,V=P.V;
|
||||
unsigned int i, V;
|
||||
|
||||
c.unpack(ciphertexts, pk);
|
||||
if (c.size() != P.sec)
|
||||
if (c.size() != P.U)
|
||||
throw length_error("number of received ciphertexts incorrect");
|
||||
|
||||
// Now check the encryptions are correct
|
||||
int ee;
|
||||
Ciphertext d1(pk.get_params()), d2(pk.get_params());
|
||||
Random_Coins rc(pk.get_params());
|
||||
ciphertexts.get(V);
|
||||
@@ -67,13 +67,7 @@ void Verifier<FD,S>::Stage_2(const vector<int>& e,
|
||||
if (!P.check_bounds(z, t, i))
|
||||
throw runtime_error("preimage out of bounds");
|
||||
d1.unpack(ciphertexts);
|
||||
for (k=0; k<P.sec; k++)
|
||||
{ int jj=(i+1)-(k+1);
|
||||
if (jj<0 || jj>= (int) P.sec) { ee=0; }
|
||||
else { ee=e[jj]; }
|
||||
if (ee!=0)
|
||||
{ add(d1,d1,c.at(jj)); }
|
||||
}
|
||||
P.apply_challenge(i, d1, c, pk);
|
||||
rc.assign(t[0], t[1], t[2]);
|
||||
pk.encrypt(d2,z,rc);
|
||||
if (!(d1 == d2))
|
||||
@@ -107,13 +101,18 @@ void Verifier<FD,S>::NIZKPoK(AddableVector<Ciphertext>& c,
|
||||
const FHE_PK& pk,bool Diag,
|
||||
bool binary)
|
||||
{
|
||||
vector<int> e(P.sec);
|
||||
P.set_challenge(ciphertexts);
|
||||
|
||||
P.get_challenge(e, ciphertexts);
|
||||
Stage_2(c,ciphertexts,cleartexts,pk,Diag,binary);
|
||||
|
||||
Stage_2(e,c,ciphertexts,cleartexts,pk,Diag,binary);
|
||||
if (P.top_gear)
|
||||
{
|
||||
assert(not Diag);
|
||||
assert(not binary);
|
||||
c += c;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template class Verifier<FFT_Data, bigint>;
|
||||
template class Verifier<P2Data, bigint>;
|
||||
template class Verifier<FFT_Data, fixint<2>>;
|
||||
template class Verifier<P2Data, fixint<2>>;
|
||||
|
||||
@@ -8,14 +8,14 @@ template <class FD, class S>
|
||||
class Verifier
|
||||
{
|
||||
AddableVector<S> z;
|
||||
AddableMatrix<S> t;
|
||||
AddableMatrix<Int_Random_Coins::value_type::value_type> t;
|
||||
|
||||
const Proof& P;
|
||||
Proof& P;
|
||||
|
||||
public:
|
||||
Verifier(const Proof& proof);
|
||||
Verifier(Proof& proof);
|
||||
|
||||
void Stage_2(const vector<int>& e,
|
||||
void Stage_2(
|
||||
AddableVector<Ciphertext>& c, octetStream& ciphertexts,
|
||||
octetStream& cleartexts,const FHE_PK& pk,bool Diag,bool binary=false);
|
||||
|
||||
|
||||
31
GC/BitAdder.h
Normal file
31
GC/BitAdder.h
Normal file
@@ -0,0 +1,31 @@
|
||||
/*
|
||||
* BitAdder.h
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef GC_BITADDER_H_
|
||||
#define GC_BITADDER_H_
|
||||
|
||||
#include <vector>
|
||||
using namespace std;
|
||||
|
||||
class BitAdder
|
||||
{
|
||||
public:
|
||||
template<class T>
|
||||
void add(vector<vector<T>>& res, const vector<vector<vector<T>>>& summands,
|
||||
SubProcessor<T>& proc, int length, ThreadQueues* queues = 0,
|
||||
int player = -1);
|
||||
|
||||
template<class T>
|
||||
void add(vector<vector<T>>& res, const vector<vector<vector<T>>>& summands,
|
||||
size_t begin, size_t end, SubProcessor<T>& proc, int length,
|
||||
int input_begin = -1, const void* supply = 0);
|
||||
|
||||
template<class T>
|
||||
void multi_add(vector<vector<T>>& res, const vector<vector<vector<T>>>& summands,
|
||||
size_t begin, size_t end, SubProcessor<T>& proc, int length,
|
||||
int input_begin);
|
||||
};
|
||||
|
||||
#endif /* GC_BITADDER_H_ */
|
||||
163
GC/BitAdder.hpp
Normal file
163
GC/BitAdder.hpp
Normal file
@@ -0,0 +1,163 @@
|
||||
/*
|
||||
* BitAdder.cpp
|
||||
*
|
||||
*/
|
||||
|
||||
#include "BitAdder.h"
|
||||
|
||||
#include <assert.h>
|
||||
|
||||
template<class T>
|
||||
void BitAdder::add(vector<vector<T>>& res, const vector<vector<vector<T>>>& summands,
|
||||
SubProcessor<T>& proc, int length, ThreadQueues* queues, int player)
|
||||
{
|
||||
assert(not summands.empty());
|
||||
assert(not summands[0].empty());
|
||||
|
||||
res.resize(summands[0][0].size());
|
||||
|
||||
if (queues)
|
||||
{
|
||||
assert(length == T::default_length);
|
||||
int n_available = queues->find_available();
|
||||
int n_per_thread = queues->get_n_per_thread(res.size());
|
||||
vector<vector<array<T, 3>>> triples(n_available);
|
||||
vector<void*> supplies(n_available);
|
||||
for (int i = 0; i < n_available; i++)
|
||||
{
|
||||
if (T::expensive_triples)
|
||||
{
|
||||
supplies[i] = &triples[i];
|
||||
for (size_t j = 0; j < n_per_thread * summands.size(); j++)
|
||||
triples[i].push_back(proc.DataF.get_triple(T::default_length));
|
||||
#ifdef VERBOSE_EDA
|
||||
cerr << "supplied " << triples[i].size() << endl;
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
ThreadJob job(&res, &summands, T::default_length, player);
|
||||
int start = queues->distribute_no_setup(job, res.size(), 0, 1,
|
||||
&supplies);
|
||||
BitAdder().add(res, summands, start,
|
||||
summands[0][0].size(), proc, T::default_length);
|
||||
queues->wrap_up(job);
|
||||
}
|
||||
else
|
||||
add(res, summands, 0, res.size(), proc, length);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void BitAdder::add(vector<vector<T> >& res,
|
||||
const vector<vector<vector<T> > >& summands, size_t begin, size_t end,
|
||||
SubProcessor<T>& proc, int length, int input_begin, const void* supply)
|
||||
{
|
||||
#ifdef VERBOSE_EDA
|
||||
fprintf(stderr, "add bits %lu to %lu\n", begin, end);
|
||||
#endif
|
||||
|
||||
if (input_begin < 0)
|
||||
input_begin = begin;
|
||||
|
||||
int n_bits = summands.size();
|
||||
for (size_t i = begin; i < end; i++)
|
||||
res[i].resize(n_bits + 1);
|
||||
|
||||
size_t n_items = end - begin;
|
||||
|
||||
if (supply)
|
||||
{
|
||||
#ifdef VERBOSE_EDA
|
||||
fprintf(stderr, "got supply\n");
|
||||
#endif
|
||||
auto& s = *(vector<array<T, 3>>*) supply;
|
||||
assert(s.size() == n_items * n_bits);
|
||||
proc.DataF.push_triples(s);
|
||||
}
|
||||
|
||||
if (summands[0].size() > 2)
|
||||
return multi_add(res, summands, begin, end, proc, length, input_begin);
|
||||
|
||||
vector<T> carries(n_items);
|
||||
auto& protocol = proc.protocol;
|
||||
for (int i = 0; i < n_bits; i++)
|
||||
{
|
||||
assert(summands[i].size() == 2);
|
||||
assert(summands[i][0].size() >= input_begin + n_items);
|
||||
assert(summands[i][1].size() >= input_begin + n_items);
|
||||
|
||||
vector<T> a(n_items), b(n_items);
|
||||
for (size_t j = 0; j < n_items; j++)
|
||||
{
|
||||
a[j] = summands[i][0][input_begin + j];
|
||||
b[j] = summands[i][1][input_begin + j];
|
||||
}
|
||||
|
||||
protocol.init_mul(&proc);
|
||||
for (size_t j = 0; j < n_items; j++)
|
||||
{
|
||||
res[begin + j][i] = a[j] + b[j] + carries[j];
|
||||
// full adder using MUX
|
||||
protocol.prepare_mul(a[j] + b[j], carries[j] + a[j], length);
|
||||
}
|
||||
protocol.exchange();
|
||||
for (size_t j = 0; j < n_items; j++)
|
||||
carries[j] = a[j] + protocol.finalize_mul(length);
|
||||
}
|
||||
|
||||
for (size_t j = 0; j < n_items; j++)
|
||||
res[begin + j][n_bits] = carries[j];
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void BitAdder::multi_add(vector<vector<T> >& res,
|
||||
const vector<vector<vector<T> > >& summands, size_t begin, size_t end,
|
||||
SubProcessor<T>& proc, int length, int input_begin)
|
||||
{
|
||||
int n_bits = summands.size() + ceil(log2(proc.P.num_players()));
|
||||
size_t n_items = end - begin;
|
||||
|
||||
assert(n_bits > 0);
|
||||
|
||||
vector<vector<vector<T>>> my_summands(n_bits);
|
||||
|
||||
for (auto& x : my_summands)
|
||||
{
|
||||
x.resize(2);
|
||||
for (auto& y : x)
|
||||
y.resize(n_items);
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < summands.size(); i++)
|
||||
for (int j = 0; j < 2; j++)
|
||||
{
|
||||
auto& x = my_summands.at(i).at(j);
|
||||
auto& z = summands.at(i).at(j);
|
||||
auto y = z.begin() + input_begin;
|
||||
assert(y + n_items <= z.end());
|
||||
x.clear();
|
||||
x.insert(x.begin(), y, y + n_items);
|
||||
}
|
||||
|
||||
vector<vector<T>> my_res(n_items);
|
||||
for (size_t k = 2; k < summands.at(0).size(); k++)
|
||||
{
|
||||
add(my_res, my_summands, 0, n_items, proc, length);
|
||||
|
||||
for (size_t i = 0; i < my_summands.size(); i++)
|
||||
for (size_t j = 0; j < n_items; j++)
|
||||
my_summands[i][0][j] = my_res[j][i];
|
||||
|
||||
for (size_t i = 0; i < summands.size(); i++)
|
||||
{
|
||||
auto& z = summands.at(i).at(k);
|
||||
auto y = z.begin() + input_begin;
|
||||
assert(y + n_items <= z.end());
|
||||
auto& x = my_summands.at(i).at(1);
|
||||
x.clear();
|
||||
x.insert(x.begin(), y, y + n_items);
|
||||
}
|
||||
}
|
||||
|
||||
add(res, my_summands, begin, end, proc, length, 0);
|
||||
}
|
||||
74
GC/CcdPrep.h
Normal file
74
GC/CcdPrep.h
Normal file
@@ -0,0 +1,74 @@
|
||||
/*
|
||||
* CcdPrep.h
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef GC_CCDPREP_H_
|
||||
#define GC_CCDPREP_H_
|
||||
|
||||
#include "Protocols/ReplicatedPrep.h"
|
||||
|
||||
class DataPositions;
|
||||
|
||||
namespace GC
|
||||
{
|
||||
|
||||
template<class T> class ShareThread;
|
||||
|
||||
template<class T>
|
||||
class CcdPrep : public BufferPrep<T>
|
||||
{
|
||||
typename T::part_type::LivePrep part_prep;
|
||||
typename T::part_type::MAC_Check part_MC;
|
||||
SubProcessor<typename T::part_type>* part_proc;
|
||||
ShareThread<T>& thread;
|
||||
|
||||
public:
|
||||
CcdPrep(DataPositions& usage, ShareThread<T>& thread) :
|
||||
BufferPrep<T>(usage), part_prep(usage), part_proc(0), thread(thread)
|
||||
{
|
||||
}
|
||||
|
||||
~CcdPrep()
|
||||
{
|
||||
if (part_proc)
|
||||
delete part_proc;
|
||||
}
|
||||
|
||||
void set_protocol(typename T::Protocol& protocol)
|
||||
{
|
||||
part_proc = new SubProcessor<typename T::part_type>(part_MC,
|
||||
part_prep, protocol.get_part().P);
|
||||
}
|
||||
|
||||
Preprocessing<typename T::part_type>& get_part()
|
||||
{
|
||||
return part_prep;
|
||||
}
|
||||
|
||||
void buffer_triples()
|
||||
{
|
||||
throw not_implemented();
|
||||
}
|
||||
|
||||
void buffer_bits()
|
||||
{
|
||||
assert(part_proc);
|
||||
for (int i = 0; i < OnlineOptions::singleton.batch_size; i++)
|
||||
this->bits.push_back(part_prep.get_bit());
|
||||
}
|
||||
|
||||
void buffer_squares()
|
||||
{
|
||||
throw not_implemented();
|
||||
}
|
||||
|
||||
void buffer_inverses()
|
||||
{
|
||||
throw not_implemented();
|
||||
}
|
||||
};
|
||||
|
||||
} /* namespace GC */
|
||||
|
||||
#endif /* GC_CCDPREP_H_ */
|
||||
72
GC/CcdSecret.h
Normal file
72
GC/CcdSecret.h
Normal file
@@ -0,0 +1,72 @@
|
||||
/*
|
||||
* CcdSecret.h
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef GC_CCDSECRET_H_
|
||||
#define GC_CCDSECRET_H_
|
||||
|
||||
#include "TinySecret.h"
|
||||
#include "CcdShare.h"
|
||||
|
||||
namespace GC
|
||||
{
|
||||
|
||||
template<class T> class TinyMC;
|
||||
template<class T> class VectorProtocol;
|
||||
template<class T> class VectorInput;
|
||||
template<class T> class CcdPrep;
|
||||
|
||||
template<class T>
|
||||
class CcdSecret : public VectorSecret<CcdShare<T>>
|
||||
{
|
||||
typedef CcdSecret This;
|
||||
typedef VectorSecret<CcdShare<T>> super;
|
||||
|
||||
public:
|
||||
typedef TinyMC<This> MC;
|
||||
typedef MC MAC_Check;
|
||||
typedef VectorProtocol<This> Protocol;
|
||||
typedef CcdPrep<This> LivePrep;
|
||||
typedef VectorInput<This> Input;
|
||||
|
||||
typedef typename This::part_type check_type;
|
||||
|
||||
static string type_short()
|
||||
{
|
||||
return "CCD";
|
||||
}
|
||||
|
||||
static MC* new_mc(typename super::mac_key_type mac_key)
|
||||
{
|
||||
return new MC(mac_key);
|
||||
}
|
||||
|
||||
CcdSecret()
|
||||
{
|
||||
}
|
||||
|
||||
CcdSecret(const typename This::part_type& other) :
|
||||
super(other)
|
||||
{
|
||||
}
|
||||
|
||||
CcdSecret(const super& other) :
|
||||
super(other)
|
||||
{
|
||||
}
|
||||
|
||||
CcdSecret(const typename This::part_type::super& other) :
|
||||
super(other)
|
||||
{
|
||||
}
|
||||
|
||||
CcdSecret(const GC::Clear& other) :
|
||||
super(other)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
} /* namespace GC */
|
||||
|
||||
#endif /* GC_CCDSECRET_H_ */
|
||||
93
GC/CcdShare.h
Normal file
93
GC/CcdShare.h
Normal file
@@ -0,0 +1,93 @@
|
||||
/*
|
||||
* CcdShare.h
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef GC_CCDSHARE_H_
|
||||
#define GC_CCDSHARE_H_
|
||||
|
||||
#include "Protocols/ShamirShare.h"
|
||||
|
||||
namespace GC
|
||||
{
|
||||
|
||||
template<class T> class CcdSecret;
|
||||
|
||||
template<class T>
|
||||
class CcdShare : public ShamirShare<T>, public ShareSecret<CcdSecret<T>>
|
||||
{
|
||||
typedef CcdShare This;
|
||||
|
||||
public:
|
||||
typedef ShamirShare<T> super;
|
||||
|
||||
typedef Bit clear;
|
||||
|
||||
typedef ReplicatedPrep<This> LivePrep;
|
||||
typedef ShamirInput<This> Input;
|
||||
|
||||
typedef ShamirMC<This> MAC_Check;
|
||||
|
||||
typedef This small_type;
|
||||
|
||||
typedef NoShare bit_type;
|
||||
|
||||
static const int default_length = 1;
|
||||
|
||||
static string name()
|
||||
{
|
||||
return "CCD";
|
||||
}
|
||||
|
||||
static MAC_Check* new_mc(T)
|
||||
{
|
||||
return new MAC_Check;
|
||||
}
|
||||
|
||||
static This new_reg()
|
||||
{
|
||||
return {};
|
||||
}
|
||||
|
||||
CcdShare()
|
||||
{
|
||||
}
|
||||
|
||||
CcdShare(const CcdSecret<T>& other) :
|
||||
super(other.get_bit(0))
|
||||
{
|
||||
}
|
||||
|
||||
template<class U>
|
||||
CcdShare(const U& other) :
|
||||
super(other)
|
||||
{
|
||||
}
|
||||
|
||||
void XOR(const This& a, const This& b)
|
||||
{
|
||||
*this = a + b;
|
||||
}
|
||||
|
||||
void public_input(bool input)
|
||||
{
|
||||
*this = input;
|
||||
}
|
||||
|
||||
void random()
|
||||
{
|
||||
CcdSecret<T> tmp;
|
||||
ShareThread<CcdSecret<T>>::s().DataF.get_one(DATA_BIT, tmp);
|
||||
*this = tmp.get_reg(0);
|
||||
}
|
||||
|
||||
This& operator^=(const This& other)
|
||||
{
|
||||
*this += other;
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
#endif /* GC_CCDSHARE_H_ */
|
||||
@@ -53,8 +53,14 @@ public:
|
||||
{ processor.andrs(args); }
|
||||
static void ands(GC::Processor<FakeSecret>& processor, const vector<int>& regs);
|
||||
template <class T>
|
||||
static void xors(GC::Processor<T>& processor, const vector<int>& regs)
|
||||
{ processor.xors(regs); }
|
||||
template <class T>
|
||||
static void inputb(T& processor, const vector<int>& args)
|
||||
{ processor.input(args); }
|
||||
template <class T>
|
||||
static void reveal_inst(T& processor, const vector<int>& args)
|
||||
{ processor.reveal(args); }
|
||||
|
||||
static void trans(Processor<FakeSecret>& processor, int n_inputs,
|
||||
const vector<int>& args);
|
||||
|
||||
@@ -64,6 +64,8 @@ enum
|
||||
STMSBI = 0x243,
|
||||
MOVSB = 0x244,
|
||||
INPUTB = 0x246,
|
||||
SPLIT = 0x248,
|
||||
CONVCBIT2S = 0x249,
|
||||
// write to clear
|
||||
CLEAR_WRITE = 0x210,
|
||||
XORCBI = 0x210,
|
||||
@@ -79,6 +81,7 @@ enum
|
||||
MULCBI = 0x21c,
|
||||
SHRCBI = 0x21d,
|
||||
SHLCBI = 0x21e,
|
||||
CONVCINTVEC = 0x21f,
|
||||
// don't write
|
||||
PRINTREGSIGNED = 0x220,
|
||||
PRINTREGB = 0x221,
|
||||
@@ -87,6 +90,7 @@ enum
|
||||
CONDPRINTSTRB = 0x224,
|
||||
// write to regint
|
||||
CONVCBIT = 0x230,
|
||||
CONVCBITVEC = 0x231,
|
||||
};
|
||||
|
||||
#endif /* PROCESSOR_GC_INSTRUCTION_H_ */
|
||||
|
||||
@@ -3,6 +3,9 @@
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef GC_INSTRUCTION_HPP_
|
||||
#define GC_INSTRUCTION_HPP_
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#include "GC/Instruction.h"
|
||||
@@ -14,6 +17,8 @@
|
||||
|
||||
#include "GC/Instruction_inline.h"
|
||||
|
||||
#include "Processor/Instruction.hpp"
|
||||
|
||||
namespace GC
|
||||
{
|
||||
|
||||
@@ -97,3 +102,5 @@ void Instruction::parse(istream& s, int pos)
|
||||
}
|
||||
|
||||
} /* namespace GC */
|
||||
|
||||
#endif
|
||||
|
||||
@@ -11,6 +11,8 @@
|
||||
#include "GC/Program.h"
|
||||
#include "ThreadMaster.h"
|
||||
|
||||
#include "Program.hpp"
|
||||
|
||||
namespace GC
|
||||
{
|
||||
|
||||
|
||||
69
GC/MaliciousCcdSecret.h
Normal file
69
GC/MaliciousCcdSecret.h
Normal file
@@ -0,0 +1,69 @@
|
||||
/*
|
||||
* MaliciousCcdSecret.h
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef GC_MALICIOUSCCDSECRET_H_
|
||||
#define GC_MALICIOUSCCDSECRET_H_
|
||||
|
||||
#include "CcdSecret.h"
|
||||
#include "MaliciousCcdShare.h"
|
||||
|
||||
namespace GC
|
||||
{
|
||||
|
||||
template<class T> class VectorInput;
|
||||
|
||||
template<class T>
|
||||
class MaliciousCcdSecret : public VectorSecret<MaliciousCcdShare<T>>
|
||||
{
|
||||
typedef MaliciousCcdSecret This;
|
||||
typedef VectorSecret<MaliciousCcdShare<T>> super;
|
||||
|
||||
public:
|
||||
typedef TinyMC<This> MC;
|
||||
typedef MC MAC_Check;
|
||||
typedef VectorProtocol<This> Protocol;
|
||||
typedef CcdPrep<This> LivePrep;
|
||||
typedef VectorInput<This> Input;
|
||||
|
||||
typedef typename This::part_type check_type;
|
||||
|
||||
static string type_short()
|
||||
{
|
||||
return "CCD";
|
||||
}
|
||||
|
||||
static MC* new_mc(typename super::mac_key_type mac_key)
|
||||
{
|
||||
return new MC(mac_key);
|
||||
}
|
||||
|
||||
MaliciousCcdSecret()
|
||||
{
|
||||
}
|
||||
|
||||
MaliciousCcdSecret(const typename This::part_type& other) :
|
||||
super(other)
|
||||
{
|
||||
}
|
||||
|
||||
MaliciousCcdSecret(const super& other) :
|
||||
super(other)
|
||||
{
|
||||
}
|
||||
|
||||
MaliciousCcdSecret(const typename This::part_type::super& other) :
|
||||
super(other)
|
||||
{
|
||||
}
|
||||
|
||||
MaliciousCcdSecret(const GC::Clear& other) :
|
||||
super(other)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
} /* namespace GC */
|
||||
|
||||
#endif /* GC_MALICIOUSCCDSECRET_H_ */
|
||||
104
GC/MaliciousCcdShare.h
Normal file
104
GC/MaliciousCcdShare.h
Normal file
@@ -0,0 +1,104 @@
|
||||
/*
|
||||
* MalicousCcdShare.h
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef GC_MALICIOUSCCDSHARE_H_
|
||||
#define GC_MALICIOUSCCDSHARE_H_
|
||||
|
||||
#include "CcdShare.h"
|
||||
#include "Protocols/MaliciousShamirShare.h"
|
||||
|
||||
template<class T> class MaliciousRepPrep;
|
||||
|
||||
namespace GC
|
||||
{
|
||||
|
||||
template<class T> class MaliciousCcdSecret;
|
||||
|
||||
template<class T>
|
||||
class MaliciousCcdShare: public MaliciousShamirShare<T>, public ShareSecret<
|
||||
MaliciousCcdSecret<T>>
|
||||
{
|
||||
typedef MaliciousCcdShare This;
|
||||
|
||||
public:
|
||||
typedef MaliciousShamirShare<T> super;
|
||||
|
||||
typedef Bit clear;
|
||||
|
||||
typedef MaliciousRepPrep<This> LivePrep;
|
||||
typedef ShamirInput<This> Input;
|
||||
typedef Beaver<This> Protocol;
|
||||
|
||||
typedef MaliciousShamirMC<This> MAC_Check;
|
||||
|
||||
typedef This small_type;
|
||||
|
||||
typedef NoShare bit_type;
|
||||
|
||||
static const int default_length = 1;
|
||||
|
||||
static string name()
|
||||
{
|
||||
return "Malicious CCD";
|
||||
}
|
||||
|
||||
static MAC_Check* new_mc(T)
|
||||
{
|
||||
return new MAC_Check;
|
||||
}
|
||||
|
||||
static This new_reg()
|
||||
{
|
||||
return {};
|
||||
}
|
||||
|
||||
MaliciousCcdShare()
|
||||
{
|
||||
}
|
||||
|
||||
MaliciousCcdShare(const MaliciousCcdSecret<T>& other) :
|
||||
super(other.get_bit(0))
|
||||
{
|
||||
}
|
||||
|
||||
template<class U>
|
||||
MaliciousCcdShare(const U& other) :
|
||||
super(other)
|
||||
{
|
||||
}
|
||||
|
||||
void XOR(const This& a, const This& b)
|
||||
{
|
||||
*this = a + b;
|
||||
}
|
||||
|
||||
void public_input(bool input)
|
||||
{
|
||||
*this = input;
|
||||
}
|
||||
|
||||
void random()
|
||||
{
|
||||
MaliciousCcdSecret<T> tmp;
|
||||
ShareThread<MaliciousCcdSecret<T>>::s().DataF.get_one(DATA_BIT, tmp);
|
||||
*this = tmp.get_reg(0);
|
||||
}
|
||||
|
||||
This& operator^=(const This& other)
|
||||
{
|
||||
*this += other;
|
||||
return *this;
|
||||
}
|
||||
|
||||
This get_bit(int i)
|
||||
{
|
||||
assert(i == 0);
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
|
||||
} /* namespace GC */
|
||||
|
||||
#endif /* GC_MALICIOUSCCDSHARE_H_ */
|
||||
@@ -21,6 +21,37 @@ namespace GC
|
||||
template<class T> class ShareThread;
|
||||
template<class T> class RepPrep;
|
||||
|
||||
class SmallMalRepSecret : public FixedVec<BitVec_<unsigned char>, 2>
|
||||
{
|
||||
typedef FixedVec<BitVec_<unsigned char>, 2> super;
|
||||
typedef SmallMalRepSecret This;
|
||||
|
||||
public:
|
||||
typedef MaliciousRepMC<This> MC;
|
||||
typedef BitVec_<unsigned char> open_type;
|
||||
typedef open_type clear;
|
||||
typedef BitVec mac_key_type;
|
||||
|
||||
static MC* new_mc(mac_key_type)
|
||||
{
|
||||
return new HashMaliciousRepMC<This>;
|
||||
}
|
||||
|
||||
SmallMalRepSecret()
|
||||
{
|
||||
}
|
||||
template<class T>
|
||||
SmallMalRepSecret(const T& other) :
|
||||
super(other)
|
||||
{
|
||||
}
|
||||
|
||||
This lsb() const
|
||||
{
|
||||
return *this & 1;
|
||||
}
|
||||
};
|
||||
|
||||
class MaliciousRepSecret : public ReplicatedSecret<MaliciousRepSecret>
|
||||
{
|
||||
typedef ReplicatedSecret<MaliciousRepSecret> super;
|
||||
@@ -36,6 +67,11 @@ public:
|
||||
typedef RepPrep<MaliciousRepSecret> LivePrep;
|
||||
|
||||
typedef MaliciousRepSecret part_type;
|
||||
typedef MaliciousRepSecret whole_type;
|
||||
|
||||
typedef SmallMalRepSecret small_type;
|
||||
|
||||
static const bool expensive_triples = true;
|
||||
|
||||
static MC* new_mc(mac_key_type)
|
||||
{
|
||||
|
||||
40
GC/NoShare.h
40
GC/NoShare.h
@@ -12,10 +12,11 @@
|
||||
namespace GC
|
||||
{
|
||||
|
||||
class NoValue
|
||||
class NoValue : public ValueInterface
|
||||
{
|
||||
public:
|
||||
const static int n_bits = 0;
|
||||
const static int MAX_N_BITS = 0;
|
||||
|
||||
static bool allows(Dtype)
|
||||
{
|
||||
@@ -32,11 +33,33 @@ public:
|
||||
throw runtime_error("VM does not support binary circuits");
|
||||
}
|
||||
|
||||
NoValue() {}
|
||||
NoValue(int) { fail(); }
|
||||
|
||||
void assign(const char*) { fail(); }
|
||||
|
||||
int get() const { fail(); return 0; }
|
||||
|
||||
int operator<<(int) const { fail(); return 0; }
|
||||
void operator+=(int) { fail(); }
|
||||
|
||||
bool get_bit(int) { fail(); return 0; }
|
||||
|
||||
void randomize(PRNG&) { fail(); }
|
||||
};
|
||||
|
||||
inline ostream& operator<<(ostream& o, NoValue)
|
||||
{
|
||||
return o;
|
||||
}
|
||||
|
||||
template<class T>
|
||||
inline bool operator!=(const T&, NoValue&)
|
||||
{
|
||||
NoValue::fail();
|
||||
return true;
|
||||
}
|
||||
|
||||
class NoShare : public Phase
|
||||
{
|
||||
public:
|
||||
@@ -52,8 +75,13 @@ public:
|
||||
|
||||
typedef NoShare bit_type;
|
||||
typedef NoShare part_type;
|
||||
typedef NoShare small_type;
|
||||
|
||||
static const int default_length = 1;
|
||||
|
||||
static const bool needs_ot = false;
|
||||
static const bool expensive_triples = false;
|
||||
static const bool is_real = false;
|
||||
|
||||
static MC* new_mc(mac_key_type)
|
||||
{
|
||||
@@ -91,10 +119,13 @@ public:
|
||||
}
|
||||
|
||||
static void inputb(Processor<NoShare>&, const vector<int>&) { fail(); }
|
||||
static void reveal_inst(Processor<NoShare>&, const vector<int>&) { fail(); }
|
||||
|
||||
static void input(Processor<NoShare>&, InputArgs&) { fail(); }
|
||||
static void trans(Processor<NoShare>&, Integer, const vector<int>&) { fail(); }
|
||||
|
||||
static NoShare constant(GC::Clear, int, mac_key_type) { fail(); return {}; }
|
||||
|
||||
NoShare() {}
|
||||
|
||||
NoShare(int) { fail(); }
|
||||
@@ -115,6 +146,13 @@ public:
|
||||
void operator^=(NoShare) { fail(); }
|
||||
|
||||
NoShare operator+(const NoShare&) const { fail(); return {}; }
|
||||
|
||||
NoShare operator+(int) const { fail(); return {}; }
|
||||
NoShare operator&(int) const { fail(); return {}; }
|
||||
NoShare operator>>(int) const { fail(); return {}; }
|
||||
|
||||
NoShare lsb() const { fail(); return {}; }
|
||||
NoShare get_bit(int) const { fail(); return {}; }
|
||||
};
|
||||
|
||||
} /* namespace GC */
|
||||
|
||||
35
GC/PersonalPrep.h
Normal file
35
GC/PersonalPrep.h
Normal file
@@ -0,0 +1,35 @@
|
||||
/*
|
||||
* PersonalPrep.h
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef GC_PERSONALPREP_H_
|
||||
#define GC_PERSONALPREP_H_
|
||||
|
||||
#include "Protocols/ReplicatedPrep.h"
|
||||
#include "ShareThread.h"
|
||||
|
||||
namespace GC
|
||||
{
|
||||
|
||||
template<class T>
|
||||
class PersonalPrep : public BufferPrep<T>
|
||||
{
|
||||
protected:
|
||||
static const int SECURE = -1;
|
||||
|
||||
const int input_player;
|
||||
|
||||
void buffer_personal_triples();
|
||||
|
||||
public:
|
||||
PersonalPrep(DataPositions& usage, int input_player);
|
||||
|
||||
void buffer_personal_triples(size_t n, ThreadQueues* queues = 0);
|
||||
void buffer_personal_triples(vector<array<T, 3>>& triples, size_t begin,
|
||||
size_t end);
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
#endif /* GC_PERSONALPREP_H_ */
|
||||
109
GC/PersonalPrep.hpp
Normal file
109
GC/PersonalPrep.hpp
Normal file
@@ -0,0 +1,109 @@
|
||||
/*
|
||||
* PersonalPrep.cpp
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef GC_PERSONALPREP_HPP_
|
||||
#define GC_PERSONALPREP_HPP_
|
||||
|
||||
#include "PersonalPrep.h"
|
||||
|
||||
namespace GC
|
||||
{
|
||||
|
||||
template<class T>
|
||||
PersonalPrep<T>::PersonalPrep(DataPositions& usage, int input_player) :
|
||||
BufferPrep<T>(usage), input_player(input_player)
|
||||
{
|
||||
assert((input_player >= 0) or (input_player == SECURE));
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void PersonalPrep<T>::buffer_personal_triples()
|
||||
{
|
||||
buffer_personal_triples(OnlineOptions::singleton.batch_size);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void PersonalPrep<T>::buffer_personal_triples(size_t batch_size, ThreadQueues* queues)
|
||||
{
|
||||
ShuffleSacrifice<T> sacri;
|
||||
batch_size = max(batch_size, (size_t)sacri.minimum_n_outputs()) + sacri.C;
|
||||
vector<array<T, 3>> triples(batch_size);
|
||||
|
||||
if (queues)
|
||||
{
|
||||
PersonalTripleJob job(&triples, input_player);
|
||||
int start = queues->distribute(job, batch_size);
|
||||
buffer_personal_triples(triples, start, batch_size);
|
||||
queues->wrap_up(job);
|
||||
}
|
||||
else
|
||||
buffer_personal_triples(triples, 0, batch_size);
|
||||
|
||||
auto &party = ShareThread<typename T::whole_type>::s();
|
||||
assert(party.P != 0);
|
||||
assert(party.MC != 0);
|
||||
auto& MC = party.MC->get_part_MC();
|
||||
auto& P = *party.P;
|
||||
GlobalPRNG G(P);
|
||||
vector<T> shares;
|
||||
for (int i = 0; i < sacri.C; i++)
|
||||
{
|
||||
int challenge = G.get_uint(triples.size());
|
||||
for (auto& x : triples[challenge])
|
||||
shares.push_back(x);
|
||||
triples.erase(triples.begin() + challenge);
|
||||
}
|
||||
PointerVector<typename T::open_type> opened;
|
||||
MC.POpen(opened, shares, P);
|
||||
for (int i = 0; i < sacri.C; i++)
|
||||
{
|
||||
array<typename T::open_type, 3> triple({{opened.next(), opened.next(),
|
||||
opened.next()}});
|
||||
if (triple[0] * triple[1] != triple[2])
|
||||
{
|
||||
cout << triple[2] << " != " << triple[0] * triple[1] << " = "
|
||||
<< triple[0] << " * " << triple[1] << endl;
|
||||
throw runtime_error("personal triple incorrect");
|
||||
}
|
||||
}
|
||||
|
||||
this->triples.insert(this->triples.end(), triples.begin(), triples.end());
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void PersonalPrep<T>::buffer_personal_triples(vector<array<T, 3>>& triples,
|
||||
size_t begin, size_t end)
|
||||
{
|
||||
#ifdef VERBOSE_EDA
|
||||
fprintf(stderr, "personal triples %zu to %zu\n", begin, end);
|
||||
RunningTimer timer;
|
||||
#endif
|
||||
auto& party = ShareThread<typename T::whole_type>::s();
|
||||
auto& MC = party.MC->get_part_MC();
|
||||
auto& P = *party.P;
|
||||
assert(input_player < P.num_players());
|
||||
typename T::Input input(MC, *this, P);
|
||||
input.reset_all(P);
|
||||
for (size_t i = begin; i < end; i++)
|
||||
{
|
||||
typename T::clear x[2];
|
||||
for (int j = 0; j < 2; j++)
|
||||
this->get_input(triples[i][j], x[j], input_player);
|
||||
if (P.my_num() == input_player)
|
||||
input.add_mine(x[0] * x[1], T::default_length);
|
||||
else
|
||||
input.add_other(input_player);
|
||||
}
|
||||
input.exchange();
|
||||
for (size_t i = begin; i < end; i++)
|
||||
triples[i][2] = input.finalize(input_player, T::default_length);
|
||||
#ifdef VERBOSE_EDA
|
||||
fprintf(stderr, "personal triples took %f seconds\n", timer.elapsed());
|
||||
#endif
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
#endif
|
||||
@@ -14,6 +14,7 @@ using namespace std;
|
||||
|
||||
#include "Math/Integer.h"
|
||||
#include "Processor/ProcessorBase.h"
|
||||
#include "Processor/Instruction.h"
|
||||
|
||||
namespace GC
|
||||
{
|
||||
@@ -84,11 +85,15 @@ public:
|
||||
void store_clear_in_dynamic(const vector<int>& args, U& dynamic_memory);
|
||||
|
||||
void xors(const vector<int>& args);
|
||||
void andm(const ::BaseInstruction& instruction);
|
||||
void and_(const vector<int>& args, bool repeat);
|
||||
void andrs(const vector<int>& args) { and_(args, true); }
|
||||
void ands(const vector<int>& args) { and_(args, false); }
|
||||
|
||||
void input(const vector<int>& args);
|
||||
void reveal(const vector<int>& args);
|
||||
|
||||
void reveal(const ::BaseInstruction& instruction);
|
||||
|
||||
void print_reg(int reg, int n);
|
||||
void print_reg_plain(Clear& value);
|
||||
|
||||
@@ -185,6 +185,14 @@ void Processor<T>::xors(const vector<int>& args)
|
||||
}
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void Processor<T>::andm(const ::BaseInstruction& instruction)
|
||||
{
|
||||
for (int i = 0; i < DIV_CEIL(instruction.get_n(), T::default_length); i++)
|
||||
S[instruction.get_r(0) + i] = S[instruction.get_r(1) + i]
|
||||
& C[instruction.get_r(2) + i];
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void Processor<T>::and_(const vector<int>& args, bool repeat)
|
||||
{
|
||||
@@ -209,6 +217,22 @@ void Processor<T>::input(const vector<int>& args)
|
||||
}
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void Processor<T>::reveal(const vector<int>& args)
|
||||
{
|
||||
for (size_t j = 0; j < args.size(); j += 3)
|
||||
{
|
||||
int n = args[j];
|
||||
int r0 = args[j + 1];
|
||||
int r1 = args[j + 2];
|
||||
if (n > max(T::default_length, Clear::N_BITS))
|
||||
assert(T::default_length == Clear::N_BITS);
|
||||
for (int i = 0; i < DIV_CEIL(n, T::default_length); i++)
|
||||
S[r1 + i].reveal(min(Clear::N_BITS, n - i * Clear::N_BITS),
|
||||
C[r0 + i]);
|
||||
}
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void Processor<T>::print_reg(int reg, int n)
|
||||
{
|
||||
@@ -232,6 +256,8 @@ void Processor<T>::print_reg_signed(unsigned n_bits, Clear& value)
|
||||
unsigned n_shift = 0;
|
||||
if (n_bits > 1)
|
||||
n_shift = sizeof(value.get()) * 8 - n_bits;
|
||||
if (n_shift > 63)
|
||||
n_shift = 0;
|
||||
T::out << dec << (value.get() << n_shift >> n_shift) << flush;
|
||||
}
|
||||
|
||||
|
||||
@@ -3,12 +3,17 @@
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef GC_PROGRAM_HPP_
|
||||
#define GC_PROGRAM_HPP_
|
||||
|
||||
#include <GC/Program.h>
|
||||
|
||||
#include "config.h"
|
||||
|
||||
#include "Tools/callgrind.h"
|
||||
|
||||
#include "Instruction.hpp"
|
||||
|
||||
namespace GC
|
||||
{
|
||||
|
||||
@@ -133,3 +138,5 @@ BreakType Program<T>::execute(Processor<T>& Proc, U& dynamic_memory,
|
||||
}
|
||||
|
||||
} /* namespace GC */
|
||||
|
||||
#endif
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
|
||||
#include "MaliciousRepSecret.h"
|
||||
#include "ShiftableTripleBuffer.h"
|
||||
#include "PersonalPrep.h"
|
||||
#include "Protocols/ReplicatedPrep.h"
|
||||
|
||||
namespace GC
|
||||
@@ -16,13 +17,13 @@ namespace GC
|
||||
template<class T> class ShareThread;
|
||||
|
||||
template<class T>
|
||||
class RepPrep : public BufferPrep<T>, ShiftableTripleBuffer<T>
|
||||
class RepPrep : public PersonalPrep<T>, ShiftableTripleBuffer<T>
|
||||
{
|
||||
ReplicatedBase* protocol;
|
||||
|
||||
public:
|
||||
RepPrep(DataPositions& usage, ShareThread<T>& thread);
|
||||
RepPrep(DataPositions& usage);
|
||||
RepPrep(DataPositions& usage, int input_player = PersonalPrep<T>::SECURE);
|
||||
~RepPrep();
|
||||
|
||||
void set_protocol(typename T::Protocol& protocol);
|
||||
@@ -33,6 +34,8 @@ public:
|
||||
void buffer_squares() { throw not_implemented(); }
|
||||
void buffer_inverses() { throw not_implemented(); }
|
||||
|
||||
void buffer_inputs(int player);
|
||||
|
||||
void get(Dtype type, T* data)
|
||||
{
|
||||
BufferPrep<T>::get(type, data);
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user