SoftSpokenOT.

This commit is contained in:
Marcel Keller
2022-08-25 13:20:46 +10:00
parent e08a6adb63
commit 6a424539c9
171 changed files with 2181 additions and 1025 deletions

4
.gitignore vendored
View File

@@ -119,3 +119,7 @@ _build/
# environment
.env
# temp doc files
doc/readme.md
doc/xml

12
.gitmodules vendored
View File

@@ -1,12 +1,18 @@
[submodule "SimpleOT"]
path = SimpleOT
path = deps/SimpleOT
url = https://github.com/mkskeller/SimpleOT
[submodule "mpir"]
path = mpir
path = deps/mpir
url = https://github.com/wbhart/mpir
[submodule "Programs/Circuits"]
path = Programs/Circuits
url = https://github.com/mkskeller/bristol-fashion
[submodule "simde"]
path = simde
path = deps/simde
url = https://github.com/simd-everywhere/simde
[submodule "deps/libOTe"]
path = deps/libOTe
url = https://github.com/mkskeller/softspoken-implementation
[submodule "deps/SimplestOT_C"]
path = deps/SimplestOT_C
url = https://github.com/mkskeller/SimplestOT_C

View File

@@ -249,6 +249,7 @@ FakeProgramParty::FakeProgramParty(int argc, const char** argv) :
}
cout << "Compiler: " << prev << endl;
P = new PlainPlayer(N, 0);
Share<gf2n_long>::MAC_Check::setup(*P);
if (argc > 4)
threshold = atoi(argv[4]);
cout << "Threshold for multi-threaded evaluation: " << threshold << endl;
@@ -280,6 +281,7 @@ FakeProgramParty::~FakeProgramParty()
cerr << "Dynamic storage: " << 1e-9 * dynamic_memory.capacity_in_bytes()
<< " GB" << endl;
#endif
Share<gf2n_long>::MAC_Check::teardown();
}
void FakeProgramParty::_compute_prfs_outputs(Key* keys)

View File

@@ -48,8 +48,6 @@ public:
static void inputbvec(GC::Processor<GC::Secret<RealGarbleWire>>& processor,
ProcessorBase& input_processor, const vector<int>& args);
RealGarbleWire(const Register& reg) : PRFRegister(reg) {}
void garble(PRFOutputs& prf_output, const RealGarbleWire<T>& left,
const RealGarbleWire<T>& right);

View File

@@ -110,7 +110,7 @@ void RealGarbleWire<T>::inputbvec(
{
GarbleInputter<T> inputter;
processor.inputbvec(inputter, input_processor, args,
inputter.party.P->my_num());
*inputter.party.P);
}
template<class T>

View File

@@ -97,8 +97,6 @@ RealProgramParty<T>::RealProgramParty(int argc, const char** argv) :
if (online_opts.live_prep)
{
mac_key.randomize(prng);
if (T::needs_ot)
BaseMachine::s().ot_setups.push_back({*P, true});
prep = new typename T::LivePrep(0, usage);
}
else
@@ -107,6 +105,7 @@ RealProgramParty<T>::RealProgramParty(int argc, const char** argv) :
prep = new Sub_Data_Files<T>(N, prep_dir, usage);
}
T::MAC_Check::setup(*P);
MC = new typename T::MAC_Check(mac_key);
garble_processor.reset(program);
@@ -219,6 +218,7 @@ RealProgramParty<T>::~RealProgramParty()
delete garble_inputter;
delete garble_protocol;
cout << "Data sent = " << data_sent * 1e-6 << " MB" << endl;
T::MAC_Check::teardown();
}
template<class T>

View File

@@ -152,7 +152,7 @@ public:
* for pipelining matters.
*/
Register(int n_parties);
Register();
void init(int n_parties);
void init(int rfd, int n_parties);
@@ -278,10 +278,6 @@ public:
static int threshold(int) { throw not_implemented(); }
static Register new_reg();
static Register tmp_reg() { return new_reg(); }
static Register and_reg() { return new_reg(); }
template<class T>
static void store(NoMemory& dest,
const vector<GC::WriteAccess<T> >& accesses) { (void)dest; (void)accesses; }
@@ -306,8 +302,6 @@ public:
void other_input(Input&, int) {}
char get_output() { return 0; }
ProgramRegister(const Register& reg) : Register(reg) {}
};
class PRFRegister : public ProgramRegister
@@ -319,8 +313,6 @@ public:
static void load(vector<GC::ReadAccess<T> >& accesses,
const NoMemory& source);
PRFRegister(const Register& reg) : ProgramRegister(reg) {}
void op(const PRFRegister& left, const PRFRegister& right, Function func);
void XOR(const Register& left, const Register& right);
void input(party_id_t from, char input = -1);
@@ -396,8 +388,6 @@ public:
static void convcbit(Integer& dest, const GC::Clear& source,
GC::Processor<GC::Secret<EvalRegister>>& proc);
EvalRegister(const Register& reg) : ProgramRegister(reg) {}
void op(const ProgramRegister& left, const ProgramRegister& right, Function func);
void XOR(const Register& left, const Register& right);
@@ -427,8 +417,6 @@ public:
static void load(vector<GC::ReadAccess<T> >& accesses,
const NoMemory& source);
GarbleRegister(const Register& reg) : ProgramRegister(reg) {}
void op(const Register& left, const Register& right, Function func);
void XOR(const Register& left, const Register& right);
void input(party_id_t from, char value = -1);
@@ -452,8 +440,6 @@ public:
static void load(vector<GC::ReadAccess<T> >& accesses,
const NoMemory& source);
RandomRegister(const Register& reg) : ProgramRegister(reg) {}
void randomize();
void op(const Register& left, const Register& right, Function func);
@@ -469,12 +455,6 @@ public:
};
inline Register::Register(int n_parties) :
garbled_entry(n_parties), external(NO_SIGNAL),
mask(NO_SIGNAL), keys(n_parties)
{
}
inline void KeyVector::operator=(const KeyVector& other)
{
resize(other.size());

View File

@@ -14,15 +14,7 @@ void ProgramRegister::inputbvec(T& processor, ProcessorBase& input_processor,
const vector<int>& args)
{
NoOpInputter inputter;
int my_num = -1;
try
{
my_num = ProgramParty::s().P->my_num();
}
catch (exception&)
{
}
processor.inputbvec(inputter, input_processor, args, my_num);
processor.inputbvec(inputter, input_processor, args, *ProgramParty::s().P);
}
template<class T>
@@ -31,7 +23,7 @@ void EvalRegister::inputbvec(T& processor, ProcessorBase& input_processor,
{
EvalInputter inputter;
processor.inputbvec(inputter, input_processor, args,
ProgramParty::s().P->my_num());
*ProgramParty::s().P);
}
template <class T>

View File

@@ -9,10 +9,10 @@
#include "CommonParty.h"
#include "Party.h"
inline Register ProgramRegister::new_reg()
inline Register::Register() :
garbled_entry(CommonParty::s().get_n_parties()), external(NO_SIGNAL),
mask(NO_SIGNAL), keys(CommonParty::s().get_n_parties())
{
return Register(CommonParty::s().get_n_parties());
}
#endif /* BMR_REGISTER_INLINE_H_ */

View File

@@ -1,5 +1,20 @@
The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here.
## 0.3.3 (Aug 25, 2022)
- Use SoftSpokenOT to avoid unclear security of KOS OT extension candidate
- Fix security bug in MAC check when using multithreading
- Fix security bug to prevent selective failure attack by checking earlier
- Fix security bug in Mama: insufficient sacrifice.
- Inverse permutation (@Quitlox)
- Easier direct compilation (@eriktaubeneck)
- Generally allow element-vector operations
- Increase maximum register size to 2^54
- Client example in Python
- Uniform base OTs across platforms
- Multithreaded base OT computation
- Faster random bit generation in two-player Semi(2k)
## 0.3.2 (May 27, 2022)
- Secure shuffling
@@ -7,7 +22,7 @@ The changelog explains changes pulled through from the private development repos
- Documented BGV encryption interface
- Optimized matrix multiplication in dealer protocol
- Fixed security bug in homomorphic encryption parameter generation
- Fixed Security bug in Temi matrix multiplication
- Fixed security bug in Temi matrix multiplication
## 0.3.1 (Apr 19, 2022)

18
CONFIG
View File

@@ -31,24 +31,21 @@ ARCH = -mtune=native -msse4.1 -msse4.2 -maes -mpclmul -mavx -mavx2 -mbmi2 -madx
ARCH = -march=native
MACHINE := $(shell uname -m)
ARM := $(shell uname -m | grep x86; echo $$?)
OS := $(shell uname -s)
ifeq ($(MACHINE), x86_64)
# set this to 0 to avoid using AVX for OT
ifeq ($(OS), Linux)
CHECK_AVX := $(shell grep -q avx /proc/cpuinfo; echo $$?)
ifeq ($(CHECK_AVX), 0)
AVX_OT = 1
else
AVX_OT = 0
endif
else
AVX_OT = 1
endif
else
ARCH =
AVX_OT = 0
endif
USE_KOS = 0
# allow to set compiler in CONFIG.mine
CXX = g++
@@ -87,7 +84,7 @@ else
BOOST = -lboost_thread $(MY_BOOST)
endif
CFLAGS += $(ARCH) $(MY_CFLAGS) $(GDEBUG) -Wextra -Wall $(OPTIM) -I$(ROOT) -pthread $(PROF) $(DEBUG) $(MOD) $(GF2N_LONG) $(PREP_DIR) $(SSL_DIR) $(SECURE) -std=c++11 -Werror
CFLAGS += $(ARCH) $(MY_CFLAGS) $(GDEBUG) -Wextra -Wall $(OPTIM) -I$(ROOT) -I$(ROOT)/deps -pthread $(PROF) $(DEBUG) $(MOD) $(GF2N_LONG) $(PREP_DIR) $(SSL_DIR) $(SECURE) -std=c++11 -Werror
CPPFLAGS = $(CFLAGS)
LD = $(CXX)
@@ -98,3 +95,10 @@ ifeq ($(USE_NTL),1)
CFLAGS += -Wno-error=unused-parameter -Wno-error=deprecated-copy
endif
endif
ifeq ($(USE_KOS),1)
CFLAGS += -DUSE_KOS
else
CFLAGS += -std=c++17
LDLIBS += -llibOTe -lcryptoTools
endif

View File

@@ -342,7 +342,8 @@ class stmcb(base.DirectMemoryWriteInstruction, base.VectorInstruction):
code = opcodes['STMCB']
arg_format = ['cb','long']
class ldmsbi(base.ReadMemoryInstruction, base.VectorInstruction):
class ldmsbi(base.ReadMemoryInstruction, base.VectorInstruction,
base.IndirectMemoryInstruction):
""" Copy secret bit memory cell with run-time address to secret bit
register.
@@ -351,8 +352,10 @@ class ldmsbi(base.ReadMemoryInstruction, base.VectorInstruction):
"""
code = opcodes['LDMSBI']
arg_format = ['sbw','ci']
direct = staticmethod(ldmsb)
class stmsbi(base.WriteMemoryInstruction, base.VectorInstruction):
class stmsbi(base.WriteMemoryInstruction, base.VectorInstruction,
base.IndirectMemoryInstruction):
""" Copy secret bit register to secret bit memory cell with run-time
address.
@@ -361,8 +364,10 @@ class stmsbi(base.WriteMemoryInstruction, base.VectorInstruction):
"""
code = opcodes['STMSBI']
arg_format = ['sb','ci']
direct = staticmethod(stmsb)
class ldmcbi(base.ReadMemoryInstruction, base.VectorInstruction):
class ldmcbi(base.ReadMemoryInstruction, base.VectorInstruction,
base.IndirectMemoryInstruction):
""" Copy clear bit memory cell with run-time address to clear bit
register.
@@ -371,8 +376,10 @@ class ldmcbi(base.ReadMemoryInstruction, base.VectorInstruction):
"""
code = opcodes['LDMCBI']
arg_format = ['cbw','ci']
direct = staticmethod(ldmcb)
class stmcbi(base.WriteMemoryInstruction, base.VectorInstruction):
class stmcbi(base.WriteMemoryInstruction, base.VectorInstruction,
base.IndirectMemoryInstruction):
""" Copy clear bit register to clear bit memory cell with run-time
address.
@@ -381,6 +388,7 @@ class stmcbi(base.WriteMemoryInstruction, base.VectorInstruction):
"""
code = opcodes['STMCBI']
arg_format = ['cb','ci']
direct = staticmethod(stmcb)
class ldmsdi(base.ReadMemoryInstruction):
code = opcodes['LDMSDI']

View File

@@ -198,6 +198,8 @@ class bits(Tape.Register, _structure, _bit):
return 0
elif self.is_long_one(other):
return self
elif isinstance(other, _vec):
return other & other.from_vec([self])
else:
return self._and(other)
@read_mem_value
@@ -241,6 +243,13 @@ class bits(Tape.Register, _structure, _bit):
return self * condition
else:
return self * cbit.conv(condition)
def expand(self, length):
if self.n in (length, None):
return self
elif self.n == 1:
return self.get_type(length).bit_compose([self] * length)
else:
raise CompilerError('cannot expand from %s to %s' % (self.n, length))
class cbits(bits):
""" Clear bits register. Helper type with limited functionality. """
@@ -295,8 +304,15 @@ class cbits(bits):
return op(self, cbits(other))
__add__ = lambda self, other: \
self.clear_op(other, inst.addcb, inst.addcbi, operator.add)
__sub__ = lambda self, other: \
self.clear_op(-other, inst.addcb, inst.addcbi, operator.add)
def __sub__(self, other):
try:
return self + -other
except:
return type(self)(regint(self) - regint(other))
def __rsub__(self, other):
return type(self)(other - regint(self))
def __neg__(self):
return type(self)(-regint(self))
def _xor(self, other):
if isinstance(other, (sbits, sbitvec)):
return NotImplemented
@@ -589,7 +605,15 @@ class sbits(bits):
rows = list(rows)
if len(rows) == 1 and rows[0].n <= rows[0].unit:
return rows[0].bit_decompose()
n_columns = rows[0].n
for row in rows:
try:
n_columns = row.n
break
except:
pass
for i in range(len(rows)):
if util.is_zero(rows[i]):
rows[i] = cls.get_type(n_columns)(0)
for row in rows:
assert(row.n == n_columns)
if n_columns == 1 and len(rows) <= cls.unit:
@@ -613,7 +637,7 @@ class sbits(bits):
def ripple_carry_adder(*args, **kwargs):
return sbitint.ripple_carry_adder(*args, **kwargs)
class sbitvec(_vec):
class sbitvec(_vec, _bit):
""" Vector of registers of secret bits, effectively a matrix of secret bits.
This facilitates parallel arithmetic operations in binary circuits.
Container types are not supported, use :py:obj:`sbitvec.get_type` for that.
@@ -656,6 +680,7 @@ class sbitvec(_vec):
[1, 0, 1]
"""
bit_extend = staticmethod(lambda v, n: v[:n] + [0] * (n - len(v)))
is_clear = False
@classmethod
def get_type(cls, n):
""" Create type for fixed-length vector of registers of secret bits.
@@ -691,10 +716,11 @@ class sbitvec(_vec):
res.v = _complement_two_extend(list(vector), n)[:n]
return res
def __init__(self, other=None, size=None):
assert size in (None, 1)
if other is not None:
if util.is_constant(other):
self.v = [sbit((other >> i) & 1) for i in range(n)]
t = sbits.get_type(size or 1)
self.v = [t(((other >> i) & 1) * ((1 << t.n) - 1))
for i in range(n)]
elif isinstance(other, _vec):
self.v = self.bit_extend(other.v, n)
elif isinstance(other, (list, tuple)):
@@ -702,6 +728,7 @@ class sbitvec(_vec):
else:
self.v = sbits.get_type(n)(other).bit_decompose()
assert len(self.v) == n
assert size is None or size == self.v[0].n
@classmethod
def load_mem(cls, address, size=None):
if size not in (None, 1):
@@ -733,8 +760,9 @@ class sbitvec(_vec):
def reveal(self):
return util.untuplify([x.reveal() for x in self.elements()])
@classmethod
def two_power(cls, nn):
return cls.from_vec([0] * nn + [1] + [0] * (n - nn - 1))
def two_power(cls, nn, size=1):
return cls.from_vec(
[0] * nn + [sbits.get_type(size)().long_one()] + [0] * (n - nn - 1))
def coerce(self, other):
if util.is_constant(other):
return self.from_vec(util.bit_decompose(other, n))
@@ -818,16 +846,14 @@ class sbitvec(_vec):
return other
def __xor__(self, other):
other = self.coerce(other)
return self.from_vec(x ^ y for x, y in zip(self.v, other))
return self.from_vec(x ^ y for x, y in zip(*self.expand(other)))
def __and__(self, other):
return self.from_vec(x & y for x, y in zip(self.v, other.v))
return self.from_vec(x & y for x, y in zip(*self.expand(other)))
def __invert__(self):
return self.from_vec(~x for x in self.v)
def if_else(self, x, y):
assert(len(self.v) == 1)
try:
return self.from_vec(util.if_else(self.v[0], a, b) \
for a, b in zip(x, y))
except:
return util.if_else(self.v[0], x, y)
return util.if_else(self.v[0], x, y)
def __iter__(self):
return iter(self.v)
def __len__(self):
@@ -890,6 +916,24 @@ class sbitvec(_vec):
elements = red.elements()
elements += odd
return self.from_vec(sbitvec(elements).v)
@classmethod
def comp_result(cls, x):
return cls.get_type(1).from_vec([x])
def expand(self, other, expand=True):
m = 1
for x in itertools.chain(self.v, other.v if isinstance(other, sbitvec) else []):
try:
m = max(m, x.n)
except:
pass
res = []
for y in self, other:
if isinstance(y, int):
res.append([x * sbits.get_type(m)().long_one()
for x in util.bit_decompose(y, len(self.v))])
else:
res.append([x.expand(m) if (expand and isinstance(x, bits)) else x for x in y.v])
return res
class bit(object):
n = 1
@@ -1139,7 +1183,7 @@ class sbitint(_bitint, _number, sbits, _sbitintbase):
:param k: bit length of input """
return _sbitintbase.pow2(self, k)
class sbitintvec(sbitvec, _number, _bitint, _sbitintbase):
class sbitintvec(sbitvec, _bitint, _number, _sbitintbase):
"""
Vector of signed integers for parallel binary computation::
@@ -1176,7 +1220,8 @@ class sbitintvec(sbitvec, _number, _bitint, _sbitintbase):
return self
other = self.coerce(other)
assert(len(self.v) == len(other.v))
v = sbitint.bit_adder(self.v, other.v)
a, b = self.expand(other)
v = sbitint.bit_adder(a, b)
return self.from_vec(v)
__radd__ = __add__
def __mul__(self, other):
@@ -1184,7 +1229,7 @@ class sbitintvec(sbitvec, _number, _bitint, _sbitintbase):
return self.from_vec(other * x for x in self.v)
elif isinstance(other, sbitfixvec):
return NotImplemented
other_bits = util.bit_decompose(other)
_, other_bits = self.expand(other, False)
m = float('inf')
for x in itertools.chain(self.v, other_bits):
try:
@@ -1228,6 +1273,8 @@ class cbitfix(object):
store_in_mem = lambda self, *args: self.v.store_in_mem(*args)
@classmethod
def _new(cls, value):
if isinstance(value, list):
return [cls._new(x) for x in value]
res = cls()
if cls.k < value.unit:
bits = value.bit_decompose(cls.k)

View File

@@ -87,15 +87,14 @@ def LtzRing(a, k):
carry = CarryOutRawLE(*reversed(list(x[:-1] for x in summands)))
msb = carry ^ summands[0][-1] ^ summands[1][-1]
return sint.conv(msb)
return
elif program.options.ring:
else:
from . import floatingpoint
require_ring_size(k, 'comparison')
m = k - 1
shift = int(program.options.ring) - k
r_prime, r_bin = MaskingBitsInRing(k)
tmp = a - r_prime
c_prime = (tmp << shift).reveal() >> shift
c_prime = (tmp << shift).reveal(False) >> shift
a = r_bin[0].bit_decompose_clear(c_prime, m)
b = r_bin[:m]
u = CarryOutRaw(a[::-1], b[::-1])
@@ -190,7 +189,7 @@ def TruncLeakyInRing(a, k, m, signed):
r = sint.bit_compose(r_bits)
if signed:
a += (1 << (k - 1))
shifted = ((a << (n_shift - m)) + (r << n_shift)).reveal()
shifted = ((a << (n_shift - m)) + (r << n_shift)).reveal(False)
masked = shifted >> n_shift
u = sint()
BitLTL(u, masked, r_bits[:n_bits], 0)
@@ -231,7 +230,7 @@ def Mod2mRing(a_prime, a, k, m, signed):
shift = int(program.options.ring) - m
r_prime, r_bin = MaskingBitsInRing(m, True)
tmp = a + r_prime
c_prime = (tmp << shift).reveal() >> shift
c_prime = (tmp << shift).reveal(False) >> shift
u = sint()
BitLTL(u, c_prime, r_bin[:m], 0)
res = (u << m) + c_prime - r_prime
@@ -261,7 +260,7 @@ def Mod2mField(a_prime, a, k, m, kappa, signed):
t[1] = a
adds(t[2], t[0], t[1])
adds(t[3], t[2], r_prime)
asm_open(c, t[3])
asm_open(True, c, t[3])
modc(c_prime, c, c2m)
if const_rounds:
BitLTC1(u, c_prime, r, kappa)
@@ -510,7 +509,7 @@ def PreMulC_with_inverses_and_vectors(p, a):
movs(w[0], r[0])
movs(a_vec[0], a[0])
vmuls(k, t[0], w, a_vec)
vasm_open(k, m, t[0])
vasm_open(k, True, m, t[0])
PreMulC_end(p, a, c, m, z)
def PreMulC_with_inverses(p, a):
@@ -538,7 +537,7 @@ def PreMulC_with_inverses(p, a):
w[1][0] = r[0][0]
for i in range(k):
muls(t[0][i], w[1][i], a[i])
asm_open(m[i], t[0][i])
asm_open(True, m[i], t[0][i])
PreMulC_end(p, a, c, m, z)
def PreMulC_without_inverses(p, a):
@@ -563,7 +562,7 @@ def PreMulC_without_inverses(p, a):
#adds(tt[0][i], t[0][i], a[i])
#subs(tt[1][i], tt[0][i], a[i])
#startopen(tt[1][i])
asm_open(u[i], t[0][i])
asm_open(True, u[i], t[0][i])
for i in range(k-1):
muls(v[i], r[i+1], s[i])
w[0] = r[0]
@@ -579,7 +578,7 @@ def PreMulC_without_inverses(p, a):
mulm(z[i], s[i], u_inv[i])
for i in range(k):
muls(t[1][i], w[i], a[i])
asm_open(m[i], t[1][i])
asm_open(True, m[i], t[1][i])
PreMulC_end(p, a, c, m, z)
def PreMulC_end(p, a, c, m, z):
@@ -646,7 +645,7 @@ def Mod2(a_0, a, k, kappa, signed):
t[1] = a
adds(t[2], t[0], t[1])
adds(t[3], t[2], r_prime)
asm_open(c, t[3])
asm_open(True, c, t[3])
from . import floatingpoint
c_0 = floatingpoint.bits(c, 1)[0]
mulci(tc, c_0, 2)

View File

@@ -181,7 +181,8 @@ class Compiler:
action="store_true",
dest="invperm",
help="speedup inverse permutation (only use in two-party, "
"semi-honest environment)")
"semi-honest environment)"
)
parser.add_option(
"-C",
"--CISC",
@@ -244,11 +245,9 @@ class Compiler:
self.VARS[op.__name__] = op
# add open and input separately due to name conflict
self.VARS["open"] = instructions.asm_open
self.VARS["vopen"] = instructions.vasm_open
self.VARS["gopen"] = instructions.gasm_open
self.VARS["vgopen"] = instructions.vgasm_open
self.VARS["input"] = instructions.asm_input
self.VARS["ginput"] = instructions.gasm_input
self.VARS["comparison"] = comparison
@@ -268,7 +267,6 @@ class Compiler:
"sgf2nuint",
"sgf2nuint32",
"sgf2nfloat",
"sfloat",
"cfloat",
"squant",
]:
@@ -276,6 +274,9 @@ class Compiler:
def prep_compile(self, name=None):
self.parse_args()
if len(self.args) < 1 and name is None:
self.parser.print_help()
exit(1)
self.build_program(name=name)
self.build_vars()
@@ -372,7 +373,7 @@ class Compiler:
)
self.prep_compile(self.compile_name)
print(
f"Compiling: {self.compile_name} from " f"func {self.compile_func.__name__}"
"Compiling: {} from {}".format(self.compile_name, self.compile_func.__name__)
)
self.compile_function()
self.finalize_compile()

View File

@@ -28,7 +28,7 @@ def shift_two(n, pos):
def maskRing(a, k):
shift = int(program.Program.prog.options.ring) - k
if program.Program.prog.use_edabit:
if program.Program.prog.use_edabit():
r_prime, r = types.sint.get_edabit(k)
elif program.Program.prog.use_dabit:
rr, r = zip(*(types.sint.get_dabit() for i in range(k)))
@@ -36,7 +36,7 @@ def maskRing(a, k):
else:
r = [types.sint.get_random_bit() for i in range(k)]
r_prime = types.sint.bit_compose(r)
c = ((a + r_prime) << shift).reveal() >> shift
c = ((a + r_prime) << shift).reveal(False) >> shift
return c, r
def maskField(a, k, kappa):
@@ -47,7 +47,7 @@ def maskField(a, k, kappa):
comparison.PRandM(r_dprime, r_prime, r, k, k, kappa)
# always signed due to usage in equality testing
a += two_power(k)
asm_open(c, a + two_power(k) * r_dprime + r_prime)
asm_open(True, c, a + two_power(k) * r_dprime + r_prime)
return c, r
@instructions_base.ret_cisc
@@ -233,7 +233,7 @@ def Inv(a):
ldi(one, 1)
inverse(t[0], t[1])
s = t[0]*a
asm_open(c[0], s)
asm_open(True, c[0], s)
# avoid division by zero for benchmarking
divc(c[1], one, c[0])
#divc(c[1], c[0], one)
@@ -281,7 +281,7 @@ def BitDecRingRaw(a, k, m):
else:
r_bits = [types.sint.get_random_bit() for i in range(m)]
r = types.sint.bit_compose(r_bits)
shifted = ((a - r) << n_shift).reveal()
shifted = ((a - r) << n_shift).reveal(False)
masked = shifted >> n_shift
bits = r_bits[0].bit_adder(r_bits, masked.bit_decompose(m))
return bits
@@ -299,7 +299,7 @@ def BitDecFieldRaw(a, k, m, kappa, bits_to_compute=None):
r = [types.sint() for i in range(m)]
comparison.PRandM(r_dprime, r_prime, r, k, m, kappa)
pow2 = two_power(k + kappa)
asm_open(c, pow2 + two_power(k) + a - two_power(m)*r_dprime - r_prime)
asm_open(True, c, pow2 + two_power(k) + a - two_power(m)*r_dprime - r_prime)
res = r[0].bit_adder(r, list(r[0].bit_decompose_clear(c,m)))
instructions_base.reset_global_vector_size()
return res
@@ -341,10 +341,10 @@ def B2U_from_Pow2(pow2a, l, kappa):
if program.Program.prog.options.ring:
n_shift = int(program.Program.prog.options.ring) - l
assert n_shift > 0
c = ((pow2a + types.sint.bit_compose(r)) << n_shift).reveal() >> n_shift
c = ((pow2a + types.sint.bit_compose(r)) << n_shift).reveal(False) >> n_shift
else:
comparison.PRandInt(t, kappa)
asm_open(c, pow2a + two_power(l) * t +
asm_open(True, c, pow2a + two_power(l) * t +
sum(two_power(i) * r[i] for i in range(l)))
comparison.program.curr_tape.require_bit_length(l + kappa)
c = list(r_bits[0].bit_decompose_clear(c, l))
@@ -386,11 +386,11 @@ def Trunc(a, l, m, kappa=None, compute_modulo=False, signed=False):
r_dprime += t1 - t2
if program.Program.prog.options.ring:
n_shift = int(program.Program.prog.options.ring) - l
c = ((a + r_dprime + r_prime) << n_shift).reveal() >> n_shift
c = ((a + r_dprime + r_prime) << n_shift).reveal(False) >> n_shift
else:
comparison.PRandInt(rk, kappa)
r_dprime += two_power(l) * rk
asm_open(c, a + r_dprime + r_prime)
asm_open(True, c, a + r_dprime + r_prime)
for i in range(1,l):
ci[i] = c % two_power(i)
c_dprime = sum(ci[i]*(x[i-1] - x[i]) for i in range(1,l))
@@ -416,7 +416,7 @@ def TruncInRing(to_shift, l, pow2m):
rev *= pow2m
r_bits = [types.sint.get_random_bit() for i in range(l)]
r = types.sint.bit_compose(r_bits)
shifted = (rev - (r << n_shift)).reveal()
shifted = (rev - (r << n_shift)).reveal(False)
masked = shifted >> n_shift
bits = types.intbitint.bit_adder(r_bits, masked.bit_decompose(l))
return types.sint.bit_compose(reversed(bits))
@@ -457,7 +457,7 @@ def Int2FL(a, gamma, l, kappa=None):
v = t.right_shift(gamma - l - 1, gamma - 1, kappa, signed=False)
else:
v = 2**(l-gamma+1) * t
p = (p + gamma - 1 - l) * (1 -z)
p = (p + gamma - 1 - l) * z.bit_not()
return v, p, z, s
def FLRound(x, mode):
@@ -530,7 +530,7 @@ def TruncPrRing(a, k, m, signed=True):
msb = r_bits[-1]
n_shift = n_ring - (k + 1)
tmp = a + r
masked = (tmp << n_shift).reveal()
masked = (tmp << n_shift).reveal(False)
shifted = (masked << 1 >> (n_shift + m + 1))
overflow = msb.bit_xor(masked >> (n_ring - 1))
res = shifted - upper + \
@@ -551,7 +551,7 @@ def TruncPrField(a, k, m, kappa=None):
k, m, kappa, use_dabit=False)
two_to_m = two_power(m)
r = two_to_m * r_dprime + r_prime
c = (b + r).reveal()
c = (b + r).reveal(False)
c_prime = c % two_to_m
a_prime = c_prime - r_prime
d = (a - a_prime) / two_to_m
@@ -667,14 +667,14 @@ def BitDecFull(a, n_bits=None, maybe_mixed=False):
def _():
for i in range(bit_length):
tbits[j][i].link(sint.get_random_bit())
c = regint(BITLT(tbits[j], pbits, bit_length).reveal())
c = regint(BITLT(tbits[j], pbits, bit_length).reveal(False))
done[j].link(c)
return (sum(done) != a.size)
for j in range(a.size):
for i in range(bit_length):
movs(bbits[i][j], tbits[j][i])
b = sint.bit_compose(bbits)
c = (a-b).reveal()
c = (a-b).reveal(False)
cmodp = c
t = bbits[0].bit_decompose_clear(p - c, bit_length)
c = longint(c, bit_length)

View File

@@ -387,6 +387,14 @@ class use(base.Instruction):
code = base.opcodes['USE']
arg_format = ['int','int','int']
@classmethod
def get_usage(cls, args):
from .program import field_types, data_types
from .util import find_in_dict
return {(find_in_dict(field_types, args[0].i),
find_in_dict(data_types, args[1].i)):
args[2].i}
class use_inp(base.Instruction):
""" Input usage. Necessary to avoid reusage while using
preprocessing from files.
@@ -398,6 +406,13 @@ class use_inp(base.Instruction):
code = base.opcodes['USE_INP']
arg_format = ['int','int','int']
@classmethod
def get_usage(cls, args):
from .program import field_types, data_types
from .util import find_in_dict
return {(find_in_dict(field_types, args[0].i), 'input', args[1].i):
args[2].i}
class use_edabit(base.Instruction):
""" edaBit usage. Necessary to avoid reusage while using
preprocessing from files. Also used to multithreading for expensive
@@ -410,6 +425,10 @@ class use_edabit(base.Instruction):
code = base.opcodes['USE_EDABIT']
arg_format = ['int','int','int']
@classmethod
def get_usage(cls, args):
return {('sedabit' if args[0].i else 'edabit', args[1].i): args[2].i}
class use_matmul(base.Instruction):
""" Matrix multiplication usage. Used for multithreading of
preprocessing.
@@ -471,6 +490,11 @@ class use_prep(base.Instruction):
code = base.opcodes['USE_PREP']
arg_format = ['str','int']
@classmethod
def get_usage(cls, args):
return {('gf2n' if cls.__name__ == 'guse_prep' else 'modp',
args[0].str): args[1].i}
class nplayers(base.Instruction):
""" Store number of players in clear integer register.
@@ -783,30 +807,6 @@ class gbitcom(base.Instruction):
return True
###
### Special GF(2) arithmetic instructions
###
@base.vectorize
class gmulbitc(base.MulBase):
r""" Clear GF(2^n) by clear GF(2) multiplication """
__slots__ = []
code = base.opcodes['GMULBITC']
arg_format = ['cgw','cg','cg']
def is_gf2n(self):
return True
@base.vectorize
class gmulbitm(base.MulBase):
r""" Secret GF(2^n) by clear GF(2) multiplication """
__slots__ = []
code = base.opcodes['GMULBITM']
arg_format = ['sgw','sg','cg']
def is_gf2n(self):
return True
###
### Arithmetic with immediate values
###
@@ -1707,6 +1707,7 @@ class writesockets(base.IOInstruction):
from registers into a socket for a specified client id. If the
protocol uses MACs, the client should be different for every party.
:param: number of arguments to follow
:param: client id (regint)
:param: message type (must be 0)
:param: vector size (int)
@@ -2162,14 +2163,19 @@ class gconvgf2n(base.Instruction):
class asm_open(base.VarArgsInstruction):
""" Reveal secret registers (vectors) to clear registers (vectors).
:param: number of argument to follow (multiple of two)
:param: number of argument to follow (odd number)
:param: check after opening (0/1)
:param: destination (cint)
:param: source (sint)
:param: (repeat the last two)...
"""
__slots__ = []
code = base.opcodes['OPEN']
arg_format = tools.cycle(['cw','s'])
arg_format = tools.chain(['int'], tools.cycle(['cw','s']))
def merge(self, other):
self.args[0] |= other.args[0]
self.args += other.args[1:]
@base.gf2n
@base.vectorize
@@ -2415,12 +2421,17 @@ class shuffle_base(base.DataInstruction):
def logn(n):
return int(math.ceil(math.log(n, 2)))
@classmethod
def n_swaps(cls, n):
logn = cls.logn(n)
return logn * 2 ** logn - 2 ** logn + 1
def add_gen_usage(self, req_node, n):
# hack for unknown usage
req_node.increment(('bit', 'inverse'), float('inf'))
# minimal usage with two relevant parties
logn = self.logn(n)
n_switches = logn * 2 ** logn
n_switches = self.n_swaps(n)
for i in range(self.n_relevant_parties):
req_node.increment((self.field_type, 'input', i), n_switches)
# multiplications for bit check
@@ -2430,7 +2441,7 @@ class shuffle_base(base.DataInstruction):
def add_apply_usage(self, req_node, n, record_size):
req_node.increment(('bit', 'inverse'), float('inf'))
logn = self.logn(n)
n_switches = logn * 2 ** logn * self.n_relevant_parties
n_switches = self.n_swaps(n) * self.n_relevant_parties
if n != 2 ** logn:
record_size += 1
req_node.increment((self.field_type, 'triple'),
@@ -2548,7 +2559,7 @@ class sqrs(base.CISC):
c = [program.curr_block.new_reg('c') for i in range(2)]
square(s[0], s[1])
subs(s[2], self.args[1], s[0])
asm_open(c[0], s[2])
asm_open(False, c[0], s[2])
mulc(c[1], c[0], c[0])
mulm(s[3], self.args[1], c[0])
adds(s[4], s[3], s[3])

View File

@@ -542,7 +542,7 @@ def cisc(function):
def get_bytes(self):
assert len(self.kwargs) < 2
res = int_to_bytes(opcodes['CISC'])
res = LongArgFormat.encode(opcodes['CISC'])
res += int_to_bytes(sum(len(x[0]) + 2 for x in self.calls) + 1)
name = self.function.__name__
String.check(name)
@@ -720,7 +720,7 @@ class IntArgFormat(ArgFormat):
class LongArgFormat(IntArgFormat):
@classmethod
def encode(cls, arg):
return struct.pack('>Q', arg)
return list(struct.pack('>Q', arg))
def __init__(self, f):
self.i = struct.unpack('>Q', f.read(8))[0]
@@ -741,6 +741,8 @@ class ImmediateGF2NAF(IntArgFormat):
class PlayerNoAF(IntArgFormat):
@classmethod
def check(cls, arg):
if not util.is_constant(arg):
raise CompilerError('Player number must be known at compile time')
super(PlayerNoAF, cls).check(arg)
if arg > 256:
raise ArgumentError(arg, 'Player number > 256')
@@ -823,7 +825,7 @@ class Instruction(object):
return (prefix << self.code_length) + self.code
def get_encoding(self):
enc = int_to_bytes(self.get_code())
enc = LongArgFormat.encode(self.get_code())
# add the number of registers if instruction flagged as has var args
if self.has_var_args():
enc += int_to_bytes(len(self.args))
@@ -958,7 +960,7 @@ class ParsedInstruction:
except AttributeError:
pass
read = lambda: struct.unpack('>I', f.read(4))[0]
full_code = read()
full_code = struct.unpack('>Q', f.read(8))[0]
code = full_code % (1 << Instruction.code_length)
self.size = full_code >> Instruction.code_length
self.type = cls.reverse_opcodes[code]

View File

@@ -243,6 +243,10 @@ def store_in_mem(value, address):
try:
value.store_in_mem(address)
except AttributeError:
if isinstance(value, (list, tuple)):
for i, x in enumerate(value):
store_in_mem(x, address + i)
return
# legacy
if value.is_clear:
if isinstance(address, cint):
@@ -261,11 +265,13 @@ def reveal(secret):
try:
return secret.reveal()
except AttributeError:
if secret.is_clear:
return secret
if secret.is_gf2n:
res = cgf2n()
else:
res = cint()
instructions.asm_open(res, secret)
instructions.asm_open(True, res, secret)
return res
@vectorize
@@ -883,10 +889,10 @@ def range_loop(loop_body, start, stop=None, step=None):
def for_range(start, stop=None, step=None):
"""
Decorator to execute loop bodies consecutively. Arguments work as
in Python :py:func:`range`, but they can by any public
in Python :py:func:`range`, but they can be any public
integer. Information has to be passed out via container types such
as :py:class:`~Compiler.types.Array` or declaring registers as
:py:obj:`global`. Note that changing Python data structures such
as :py:class:`~Compiler.types.Array` or using :py:func:`update`.
Note that changing Python data structures such
as lists within the loop is not possible, but the compiler cannot
warn about this.
@@ -901,13 +907,11 @@ def for_range(start, stop=None, step=None):
@for_range(n)
def _(i):
a[i] = i
global x
x += 1
x.update(x + 1)
Note that you cannot overwrite data structures such as
:py:class:`~Compiler.types.Array` in a loop even when using
:py:obj:`global`. Use :py:func:`~Compiler.types.Array.assign`
instead.
:py:class:`~Compiler.types.Array` in a loop. Use
:py:func:`~Compiler.types.Array.assign` instead.
"""
def decorator(loop_body):
range_loop(loop_body, start, stop, step)
@@ -1518,6 +1522,11 @@ def if_then(condition):
state = State()
if callable(condition):
condition = condition()
try:
if not condition.is_clear:
raise CompilerError('cannot branch on secret values')
except AttributeError:
pass
state.condition = regint.conv(condition)
state.start_block = instructions.program.curr_block
state.req_child = get_tape().open_scope(lambda x: x[0].max(x[1]), \
@@ -1889,7 +1898,7 @@ def FPDiv(a, b, k, f, kappa, simplex_flag=False, nearest=False):
theta = int(ceil(log(k/3.5) / log(2)))
base.set_global_vector_size(b.size)
alpha = b.get_type(2 * k).two_power(2*f)
alpha = b.get_type(2 * k).two_power(2*f, size=b.size)
w = AppRcr(b, k, f, kappa, simplex_flag, nearest).extend(2 * k)
x = alpha - b.extend(2 * k) * w
base.reset_global_vector_size()

View File

@@ -148,7 +148,7 @@ def argmax(x):
""" Compute index of maximum element.
:param x: iterable
:returns: sint
:returns: sint or 0 if :py:obj:`x` has length 1
"""
def op(a, b):
comp = (a[1] > b[1])
@@ -164,7 +164,7 @@ def softmax(x):
return softmax_from_exp(exp_for_softmax(x)[0])
def exp_for_softmax(x):
m = util.max(x)
m = util.max(x) - get_limit(x[0]) + 1 + math.log(len(x), 2)
mv = m.expand_to_vector(len(x))
try:
x = x.get_vector()
@@ -2384,6 +2384,11 @@ class Optimizer:
for layer in self.layers:
layer.output_weights()
def summary(self):
sizes = [var.total_size() for var in self.thetas]
print(sizes)
print('Trainable params:', sum(sizes))
class Adam(Optimizer):
""" Adam/AMSgrad optimizer.
@@ -2653,9 +2658,7 @@ class keras:
return list(self.opt.thetas)
def summary(self):
sizes = [var.total_size() for var in self.trainable_variables]
print(sizes)
print('Trainable params:', sum(sizes))
self.opt.summary()
def build(self, input_shape, batch_size=128):
data_input_shape = input_shape

View File

@@ -295,7 +295,6 @@ def exp2_fx(a, zero_output=False, as19=False):
intbitint = types.intbitint
n_shift = int(types.program.options.ring) - a.k
if types.program.use_split():
assert not zero_output
from Compiler.GC.types import sbitvec
if types.program.use_split() == 3:
x = a.v.split_to_two_summands(a.k)
@@ -327,6 +326,7 @@ def exp2_fx(a, zero_output=False, as19=False):
s = sint.conv(bits[-1])
lower = sint.bit_compose(sint.conv(b) for b in bits[:a.f])
higher_bits = bits[a.f:n_bits]
bits_to_check = bits[n_bits:-1]
else:
if types.program.use_edabit():
l = sint.get_edabit(a.f, True)
@@ -338,7 +338,7 @@ def exp2_fx(a, zero_output=False, as19=False):
r_bits = [sint.get_random_bit() for i in range(a.k)]
r = sint.bit_compose(r_bits)
lower_r = sint.bit_compose(r_bits[:a.f])
shifted = ((a.v - r) << n_shift).reveal()
shifted = ((a.v - r) << n_shift).reveal(False)
masked_bits = (shifted >> n_shift).bit_decompose(a.k)
lower_overflow = comparison.CarryOutRaw(masked_bits[a.f-1::-1],
r_bits[a.f-1::-1])

View File

@@ -545,7 +545,7 @@ class Program(object):
"""
if change is None:
if not self._invperm:
self.relevant_opts.add('invperm')
self.relevant_opts.add("invperm")
return self._invperm
else:
self._invperm = change
@@ -1276,7 +1276,7 @@ class Tape:
"can_eliminate",
"duplicates",
]
maximum_size = 2 ** (32 - inst_base.Instruction.code_length) - 1
maximum_size = 2 ** (64 - inst_base.Instruction.code_length) - 1
def __init__(self, reg_type, program, size=None, i=None):
"""Creates a new register.
@@ -1382,6 +1382,20 @@ class Tape:
for dup in self.duplicates:
dup.duplicates = self.duplicates
def update(self, other):
"""
Update register. Useful in loops like
:py:func:`~Compiler.library.for_range`.
:param other: any convertible type
"""
other = self.conv(other)
if self.program != other.program:
raise CompilerError(
'cannot update register with one from another thread')
self.link(other)
@property
def is_gf2n(self):
return (

View File

@@ -127,8 +127,14 @@ def vectorize(operation):
if (isinstance(args[0], Tape.Register) or isinstance(args[0], sfloat)) \
and not isinstance(args[0], bits) \
and args[0].size != self.size:
raise VectorMismatch('Different vector sizes of operands: %d/%d'
% (self.size, args[0].size))
if min(args[0].size, self.size) == 1:
size = max(args[0].size, self.size)
self = self.expand_to_vector(size)
args = list(args)
args[0] = args[0].expand_to_vector(size)
else:
raise VectorMismatch('Different vector sizes of operands: %d/%d'
% (self.size, args[0].size))
set_global_vector_size(self.size)
try:
res = operation(self, *args, **kwargs)
@@ -249,8 +255,11 @@ class _number(Tape._no_truth):
try:
return self.mul(other)
except VectorMismatch:
# try reverse multiplication
return NotImplemented
if type(self) != type(other) and 1 in (self.size, other.size):
# try reverse multiplication
return NotImplemented
else:
raise
__radd__ = __add__
__rmul__ = __mul__
@@ -1658,6 +1667,8 @@ class regint(_register, _int):
"""
if player == None:
player = -1
if not util.is_constant(player):
raise CompilerError('Player number must be known at compile time')
intoutput(player, self)
class localint(Tape._no_truth):
@@ -2081,12 +2092,15 @@ class _secret(_register, _secret_structure):
return self.secret_op(other, subs, submr, subsfi, True)
__rsub__.__doc__ = __sub__.__doc__
@vectorize
def __truediv__(self, other):
""" Secret field division.
:param other: any compatible type """
return self * (self.clear_type(1) / other)
try:
one = self.clear_type(1, size=other.size)
except AttributeError:
one = self.clear_type(1)
return self * (one / other)
@vectorize
def __rtruediv__(self, other):
@@ -2113,12 +2127,12 @@ class _secret(_register, _secret_structure):
@set_instruction_type
@vectorize
def reveal(self):
def reveal(self, check=True):
""" Reveal secret value publicly.
:rtype: relevant clear type """
res = self.clear_type()
asm_open(res, self)
asm_open(check, res, self)
return res
@set_instruction_type
@@ -2166,9 +2180,7 @@ class sint(_secret, _int):
signed integer in a restricted range, see below. The same holds
for ``abs()``, shift operators (``<<, >>``), modulo (``%``), and
exponentation (``**``). Modulo only works if the right-hand
operator is a compile-time power of two, and exponentiation only
works if the base is two or if the exponent is a compile-time
integer.
operator is a compile-time power of two.
Most non-linear operations require compile-time parameters for bit
length and statistical security. They default to the global
@@ -2672,7 +2684,7 @@ class sint(_secret, _int):
return comparison.TruncZeros(self, bit_length, n_zeros, signed)
@staticmethod
def two_power(n):
def two_power(n, size=None):
return floatingpoint.two_power(n)
def split_to_n_summands(self, length, n):
@@ -2690,7 +2702,6 @@ class sint(_secret, _int):
columns = self.split_to_n_summands(length, n)
return _bitint.wallace_tree_without_finish(columns, get_carry)
@vectorize
def reveal_to(self, player):
""" Reveal secret value to :py:obj:`player`.
@@ -2698,13 +2709,14 @@ class sint(_secret, _int):
:returns: :py:class:`personal`
"""
if not util.is_constant(player):
secret_mask = sint()
player_mask = cint()
inputmaskreg(secret_mask, player_mask, regint.conv(player))
secret_mask = sint(size=self.size)
player_mask = cint(size=self.size)
inputmaskreg(secret_mask, player_mask,
regint.conv(player).expand_to_vector(self.size))
return personal(player,
(self + secret_mask).reveal() - player_mask)
(self + secret_mask).reveal(False) - player_mask)
else:
res = personal(player, self.clear_type())
res = personal(player, self.clear_type(size=self.size))
privateoutput(self.size, player, res._v, self)
return res
@@ -2856,6 +2868,10 @@ class sintbit(sint):
else:
return super(sintbit, self).__rsub__(other)
__rand__ = __and__
__rxor__ = __xor__
__ror__ = __or__
class sgf2n(_secret, _gf2n):
"""
Secret :math:`\mathrm{GF}(2^n)` value. n is chosen at runtime. A
@@ -2873,6 +2889,7 @@ class sgf2n(_secret, _gf2n):
instruction_type = 'gf2n'
clear_type = cgf2n
reg_type = 'sg'
long_one = staticmethod(lambda: 1)
@classmethod
def get_type(cls, length):
@@ -3022,6 +3039,7 @@ class _bitint(Tape._no_truth):
bits = None
log_rounds = False
linear_rounds = False
comp_result = staticmethod(lambda x: x)
@staticmethod
def half_adder(a, b):
@@ -3241,12 +3259,16 @@ class _bitint(Tape._no_truth):
del carries[-1]
return sums, carries
def expand(self, other):
a = self.bit_decompose()
b = util.bit_decompose(other, self.n_bits)
return a, b
def __sub__(self, other):
if type(other) == sgf2n:
raise CompilerError('Unclear subtraction')
a = self.bit_decompose()
b = util.bit_decompose(other, self.n_bits)
from util import bit_not, bit_and, bit_xor
a, b = self.expand(other)
n = 1
for x in (a + b):
try:
@@ -3293,8 +3315,7 @@ class _bitint(Tape._no_truth):
a[-1], b[-1] = b[-1], a[-1]
def comparison(self, other, const_rounds=False, index=None):
a = self.bit_decompose()
b = util.bit_decompose(other, self.n_bits)
a, b = self.expand(other)
self.prep_comparison(a, b)
if const_rounds:
return self.get_highest_different_bits(a, b, index)
@@ -3304,30 +3325,33 @@ class _bitint(Tape._no_truth):
def __lt__(self, other):
if program.options.comparison == 'log':
x, not_equal = self.comparison(other)
return util.if_else(not_equal, x, 0)
res = util.if_else(not_equal, x, 0)
else:
return self.comparison(other, True, 1)
res = self.comparison(other, True, 1)
return self.comp_result(res)
def __le__(self, other):
if program.options.comparison == 'log':
x, not_equal = self.comparison(other)
return util.if_else(not_equal, x, 1)
res = util.if_else(not_equal, x, x.long_one())
else:
return 1 - self.comparison(other, True, 0)
res = self.comparison(other, True, 0).bit_not()
return self.comp_result(res)
def __ge__(self, other):
return 1 - (self < other)
return (self < other).bit_not()
def __gt__(self, other):
return 1 - (self <= other)
return (self <= other).bit_not()
def __eq__(self, other, bit_length=None, security=None):
diff = self ^ other
diff_bits = [1 - x for x in diff.bit_decompose()[:bit_length]]
return floatingpoint.KMul(diff_bits)
diff_bits = [x.bit_not() for x in diff.bit_decompose()[:bit_length]]
return self.comp_result(util.tree_reduce(lambda x, y: x.bit_and(y),
diff_bits))
def __ne__(self, other):
return 1 - (self == other)
return (self == other).bit_not()
equal = __eq__
@@ -3881,7 +3905,6 @@ class cfix(_number, _structure):
def output_if(self, cond):
cond_print_plain(cint.conv(cond), self.v, cint(-self.f, size=self.size))
@vectorize
def binary_output(self, player=None):
""" Write double-precision floating-point number to
``Player-Data/Binary-Output-P<playerno>-<threadno>``.
@@ -3890,7 +3913,11 @@ class cfix(_number, _structure):
"""
if player == None:
player = -1
if not util.is_constant(player):
raise CompilerError('Player number must be known at compile time')
set_global_vector_size(self.size)
floatoutput(player, self.v, cint(-self.f), cint(0), cint(0))
reset_global_vector_size()
class _single(_number, _secret_structure):
""" Representation as single integer preserving the order """
@@ -4124,6 +4151,7 @@ class _single(_number, _secret_structure):
class _fix(_single):
""" Secret fixed point type. """
__slots__ = ['v', 'f', 'k']
is_clear = False
def set_precision(cls, f, k = None):
cls.f = f
@@ -4349,6 +4377,18 @@ class _fix(_single):
""" Bit decomposition. """
return self.v.bit_decompose(n_bits or self.k)
def update(self, other):
"""
Update register. Useful in loops like
:py:func:`~Compiler.library.for_range`.
:param other: any convertible type
"""
other = self.conv(other)
assert self.f == other.f
self.v.update(other.v)
class sfix(_fix):
""" Secret fixed-point number represented as secret integer, by
multiplying with ``2^f`` and then rounding. See :py:class:`sint`
@@ -4737,6 +4777,8 @@ class sfloat(_number, _secret_structure):
returning :py:class:`sint`. The other operand can be any of
sint/cfix/regint/cint/int/float.
This data type only works with arithmetic computation.
:param v: initialization (sfloat/sfix/float/int/sint/cint/regint)
"""
__slots__ = ['v', 'p', 'z', 's', 'size']
@@ -4835,6 +4877,9 @@ class sfloat(_number, _secret_structure):
@vectorize_init
@read_mem_value
def __init__(self, v, p=None, z=None, s=None, size=None):
if program.options.binary:
raise CompilerError(
'floating-point operations not supported with binary circuits')
self.size = get_global_vector_size()
if p is None:
if isinstance(v, sfloat):
@@ -5227,7 +5272,13 @@ class Array(_vectorizable):
def create_from(cls, l):
""" Convert Python iterator or vector to array. Basic type will be taken
from first element, further elements must to be convertible to
that. """
that.
:param l: Python iterable or register vector
:returns: :py:class:`Array` of appropriate type containing the contents
of :py:obj:`l`
"""
if isinstance(l, cls):
return l
if isinstance(l, _number):
@@ -6099,12 +6150,12 @@ class SubMultiArray(_vectorizable):
try:
res_matrix[i] = self.value_type.row_matrix_mul(
self[i], other, res_params)
except AttributeError:
except (AttributeError, CompilerError):
# fallback for binary circuits
@library.for_range(other.sizes[1])
@library.for_range_opt(other.sizes[1])
def _(j):
res_matrix[i][j] = 0
@library.for_range(self.sizes[1])
@library.for_range_opt(self.sizes[1])
def _(k):
res_matrix[i][j] += self[i][k] * other[k][j]
return res_matrix
@@ -6223,13 +6274,7 @@ class SubMultiArray(_vectorizable):
res[i] = self.direct_mul_trans(other, indices=indices)
def direct_mul_to_matrix(self, other):
""" Matrix multiplication in the virtual machine.
:param self: :py:class:`Matrix` / 2-dimensional :py:class:`MultiArray`
:param other: :py:class:`Matrix` / 2-dimensional :py:class:`MultiArray`
:returns: :py:obj:`Matrix`
"""
# Obsolete. Use dot().
res = self.value_type.Matrix(self.sizes[0], other.sizes[1])
res.assign_vector(self.direct_mul(other))
return res

View File

@@ -238,6 +238,9 @@ def mem_size(x):
except AttributeError:
return 1
def find_in_dict(d, v):
return list(d.keys())[list(d.values()).index(v)]
class set_by_id(object):
def __init__(self, init=[]):
self.content = {}

View File

@@ -22,7 +22,7 @@ private:
EC_POINT* point;
public:
typedef void next;
typedef P256Element next;
typedef void Square;
static const true_type invertible;

View File

@@ -45,12 +45,13 @@ int main(int argc, const char** argv)
string prefix = get_prep_sub_dir<pShare>(PREP_DIR "ECDSA/", 2);
read_mac_key(prefix, N, keyp);
pShare::MAC_Check::setup(P);
Share<P256Element>::MAC_Check::setup(P);
DataPositions usage;
Sub_Data_Files<pShare> prep(N, prefix, usage);
typename pShare::Direct_MC MCp(keyp);
ArithmeticProcessor _({}, 0);
BaseMachine machine;
machine.ot_setups.push_back({P, false});
SubProcessor<pShare> proc(_, MCp, prep, P);
pShare sk, __;
@@ -60,4 +61,7 @@ int main(int argc, const char** argv)
preprocessing(tuples, n_tuples, sk, proc, opts);
check(tuples, sk, keyp, P);
sign_benchmark(tuples, sk, MCp, P, opts);
pShare::MAC_Check::teardown();
Share<P256Element>::MAC_Check::teardown();
}

View File

@@ -92,9 +92,6 @@ void run(int argc, const char** argv)
P256Element::init();
P256Element::Scalar::next::init_field(P256Element::Scalar::pr(), false);
BaseMachine machine;
machine.ot_setups.push_back({P, true});
P256Element::Scalar keyp;
SeededPRNG G;
keyp.randomize(G);
@@ -102,6 +99,9 @@ void run(int argc, const char** argv)
typedef T<P256Element::Scalar> pShare;
DataPositions usage;
pShare::MAC_Check::setup(P);
T<P256Element>::MAC_Check::setup(P);
OnlineOptions::singleton.batch_size = 1;
typename pShare::Direct_MC MCp(keyp);
ArithmeticProcessor _({}, 0);
@@ -137,4 +137,7 @@ void run(int argc, const char** argv)
preprocessing(tuples, n_tuples, sk, proc, opts);
//check(tuples, sk, keyp, P);
sign_benchmark(tuples, sk, MCp, P, opts, prep_mul ? 0 : &proc);
pShare::MAC_Check::teardown();
T<P256Element>::MAC_Check::teardown();
}

View File

@@ -130,7 +130,7 @@ void Ciphertext::rerandomize(const FHE_PK& pk)
assert(p != 0);
for (auto& x : r)
{
G.get<FFT_Data::S>(x, params->p0().numBits() - p.numBits() - 1);
G.get(x, params->p0().numBits() - p.numBits() - 1);
x *= p;
}
tmp.from(r, 0);

View File

@@ -368,7 +368,8 @@ ZZX Cyclotomic(int N)
int phi_N(int N)
{
if (((N - 1) & N) != 0)
throw runtime_error("compile with NTL support");
throw runtime_error(
"compile with NTL support (USE_NTL=1 in CONFIG.mine)");
else if (N == 1)
return 1;
else
@@ -418,7 +419,8 @@ void init(Ring& Rg, int m, bool generate_poly)
for (int i=0; i<Rg.phim+1; i++)
{ Rg.poly[i]=to_int(coeff(P,i)); }
#else
throw runtime_error("compile with NTL support");
throw runtime_error(
"compile with NTL support (USE_NTL=1 in CONFIG.mine)");
#endif
}
}

View File

@@ -40,10 +40,9 @@ void PPData::to_eval(vector<modp>& elem) const
*/
}
void PPData::from_eval(vector<modp>& elem) const
void PPData::from_eval(vector<modp>&) const
{
// avoid warning
elem.empty();
throw not_implemented();
/*

View File

@@ -17,15 +17,13 @@ PairwiseMachine::PairwiseMachine(Player& P) :
{
}
PairwiseMachine::PairwiseMachine(int argc, const char** argv) :
MachineBase(argc, argv), P(*new PlainPlayer(N, "pairwise")),
other_pks(N.num_players(), {setup_p.params, 0}),
pk(other_pks[N.my_num()]), sk(pk)
RealPairwiseMachine::RealPairwiseMachine(int argc, const char** argv) :
MachineBase(argc, argv), PairwiseMachine(*new PlainPlayer(N, "pairwise"))
{
init();
}
void PairwiseMachine::init()
void RealPairwiseMachine::init()
{
if (use_gf2n)
{
@@ -63,7 +61,7 @@ PairwiseSetup<P2Data>& PairwiseMachine::setup()
}
template <class FD>
void PairwiseMachine::setup_keys()
void RealPairwiseMachine::setup_keys()
{
auto& N = P;
PairwiseSetup<FD>& s = setup<FD>();
@@ -84,10 +82,11 @@ void PairwiseMachine::setup_keys()
if (i != N.my_num())
other_pks[i].unpack(os[i]);
set_mac_key(s.alphai);
Share<typename FD::T>::MAC_Check::setup(P);
}
template <class T>
void PairwiseMachine::set_mac_key(T alphai)
void RealPairwiseMachine::set_mac_key(T alphai)
{
typedef typename T::FD FD;
auto& N = P;
@@ -142,5 +141,5 @@ void PairwiseMachine::check(Player& P) const
bundle.compare(P);
}
template void PairwiseMachine::setup_keys<FFT_Data>();
template void PairwiseMachine::setup_keys<P2Data>();
template void RealPairwiseMachine::setup_keys<FFT_Data>();
template void RealPairwiseMachine::setup_keys<P2Data>();

View File

@@ -10,7 +10,7 @@
#include "FHEOffline/SimpleMachine.h"
#include "FHEOffline/PairwiseSetup.h"
class PairwiseMachine : public MachineBase
class PairwiseMachine : public virtual MachineBase
{
public:
PairwiseSetup<FFT_Data> setup_p;
@@ -23,15 +23,6 @@ public:
vector<Ciphertext> enc_alphas;
PairwiseMachine(Player& P);
PairwiseMachine(int argc, const char** argv);
void init();
template <class FD>
void setup_keys();
template <class T>
void set_mac_key(T alphai);
template <class FD>
PairwiseSetup<FD>& setup();
@@ -42,4 +33,18 @@ public:
void check(Player& P) const;
};
class RealPairwiseMachine : public virtual MachineBase, public virtual PairwiseMachine
{
public:
RealPairwiseMachine(int argc, const char** argv);
void init();
template <class FD>
void setup_keys();
template <class T>
void set_mac_key(T alphai);
};
#endif /* FHEOFFLINE_PAIRWISEMACHINE_H_ */

View File

@@ -12,7 +12,7 @@
template <template <class> class T, class FD>
SimpleGenerator<T,FD>::SimpleGenerator(const Names& N, const PartSetup<FD>& setup,
const MultiplicativeMachine& machine,
const MultiplicativeMachineParams& machine,
int thread_num, Dtype data_type, Player* player) :
GeneratorBase(thread_num, N, player),
setup(setup), machine(machine),

View File

@@ -14,7 +14,7 @@
#include "Processor/Data_Files.h"
class SimpleMachine;
class MultiplicativeMachine;
class MultiplicativeMachineParams;
class GeneratorBase
{
@@ -53,7 +53,7 @@ template <template <class FD> class T, class FD>
class SimpleGenerator : public GeneratorBase
{
const PartSetup<FD>& setup;
const MultiplicativeMachine& machine;
const MultiplicativeMachineParams& machine;
size_t volatile_memory;
@@ -63,7 +63,7 @@ public:
Producer<FD>* producer;
SimpleGenerator(const Names& N, const PartSetup<FD>& setup,
const MultiplicativeMachine& machine, int thread_num,
const MultiplicativeMachineParams& machine, int thread_num,
Dtype data_type = DATA_TRIPLE, Player* player = 0);
~SimpleGenerator();

View File

@@ -16,6 +16,7 @@
#include "Protocols/fake-stuff.hpp"
#include "Protocols/mac_key.hpp"
#include "Protocols/Share.hpp"
#include "Protocols/MAC_Check.hpp"
#include "Math/modp.hpp"
void* run_generator(void* generator)
@@ -101,6 +102,19 @@ void MachineBase::parse_options(int argc, const char** argv)
start_networking_with_server(hostname, portnum_base);
}
MultiplicativeMachine::MultiplicativeMachine() :
P(N, "machine-coordinator")
{
Share<gfp>::MAC_Check::setup(P);
Share<gf2n_short>::MAC_Check::setup(P);
}
MultiplicativeMachine::~MultiplicativeMachine()
{
Share<gfp>::MAC_Check::teardown();
Share<gf2n_short>::MAC_Check::teardown();
}
void MultiplicativeMachine::parse_options(int argc, const char** argv)
{
opt.add(

View File

@@ -58,8 +58,20 @@ public:
void check(Player&) const {}
};
class MultiplicativeMachine : public MachineBase
class MultiplicativeMachineParams : public MachineBase
{
public:
DataSetup setup;
virtual ~MultiplicativeMachineParams() {}
virtual int get_covert() const { return 0; }
};
class MultiplicativeMachine : public MultiplicativeMachineParams
{
PlainPlayer P;
protected:
void parse_options(int argc, const char** argv);
@@ -68,11 +80,8 @@ protected:
void fake_keys(int slack);
public:
DataSetup setup;
virtual ~MultiplicativeMachine() {}
virtual int get_covert() const { return 0; }
MultiplicativeMachine();
virtual ~MultiplicativeMachine();
};
class SimpleMachine : public MultiplicativeMachine

View File

@@ -37,11 +37,6 @@ public:
return new MAC_Check;
}
static This new_reg()
{
return {};
}
AtlasShare()
{
}

View File

@@ -45,11 +45,6 @@ public:
return new MAC_Check;
}
static This new_reg()
{
return {};
}
CcdShare()
{
}

View File

@@ -93,7 +93,8 @@ void FakeSecret::inputbvec(Processor<FakeSecret>& processor,
{
Input input;
input.reset_all(*ShareThread<FakeSecret>::s().P);
processor.inputbvec(input, input_processor, args, 0);
processor.inputbvec(input, input_processor, args,
*ShareThread<FakeSecret>::s().P);
}
void FakeSecret::and_(int n, const FakeSecret& x, const FakeSecret& y,

View File

@@ -49,11 +49,6 @@ public:
return new MAC_Check;
}
static This new_reg()
{
return {};
}
MaliciousCcdShare()
{
}

View File

@@ -96,7 +96,7 @@ public:
void inputb(typename T::Input& input, ProcessorBase& input_processor,
const vector<int>& args, int my_num);
void inputbvec(typename T::Input& input, ProcessorBase& input_processor,
const vector<int>& args, int my_num);
const vector<int>& args, PlayerBase& P);
void reveal(const vector<int>& args);

View File

@@ -129,7 +129,7 @@ Secret<T>::Secret()
template<class T>
T& GC::Secret<T>::get_new_reg()
{
registers.push_back(T::new_reg());
registers.push_back(T());
T& res = registers.back();
#ifdef DEBUG_REGS
cout << "Secret: new " << typeid(T).name() << " " << res.get_id() << " at " << &res << endl;

View File

@@ -20,7 +20,7 @@ namespace GC {
template <class T>
inline T XOR(const T& left, const T& right)
{
T res(T::new_reg());
T res;
XOR<T>(res, left, right);
return res;
}
@@ -37,7 +37,7 @@ inline void AND(T& res, const T& left, const T& right)
template <class T>
inline T AND(const T& left, const T& right)
{
T res = T::new_reg();
T res;
AND<T>(res, left, right);
return res;
}
@@ -66,7 +66,7 @@ template <class T>
void Secret<T>::invert(int n, const Secret<T>& x)
{
resize_regs(n);
T one = T::new_reg();
T one;
one.public_input(1);
for (int i = 0; i < n; i++)
{
@@ -87,7 +87,7 @@ template <class T>
inline void Secret<T>::resize_regs(size_t n)
{
if (registers.size() != n)
registers.resize(n, T::new_reg());
registers.resize(n);
}
} /* namespace GC */

View File

@@ -104,18 +104,11 @@ ShareParty<T>::ShareParty(int argc, const char** argv, ez::ezOptionParser& opt,
network_opts.start_networking(this->N, my_num);
if (online_opts.live_prep)
if (T::needs_ot)
{
Player* P;
if (this->machine.use_encryption)
P = new CryptoPlayer(this->N, "shareparty");
else
P = new PlainPlayer(this->N, "shareparty");
for (int i = 0; i < this->machine.nthreads; i++)
this->machine.ot_setups.push_back({*P, true});
delete P;
}
Player* P;
if (this->machine.use_encryption)
P = new CryptoPlayer(this->N, "shareparty");
else
P = new PlainPlayer(this->N, "shareparty");
try
{
@@ -130,8 +123,13 @@ ShareParty<T>::ShareParty(int argc, const char** argv, ez::ezOptionParser& opt,
this->mac_key.randomize(G);
}
T::MC::setup(*P);
this->run();
T::MC::teardown();
delete P;
this->machine.write_memory(this->N.my_num());
}

View File

@@ -141,7 +141,7 @@ void ShareSecret<U>::inputbvec(Processor<U>& processor,
auto& party = ShareThread<U>::s();
typename U::Input input(*party.MC, party.DataF, *party.P);
input.reset_all(*party.P);
processor.inputbvec(input, input_processor, args, party.P->my_num());
processor.inputbvec(input, input_processor, args, *party.P);
}
template <class T>
@@ -192,14 +192,18 @@ void Processor<T>::inputb(typename T::Input& input, ProcessorBase& input_process
template <class T>
void Processor<T>::inputbvec(typename T::Input& input, ProcessorBase& input_processor,
const vector<int>& args, int my_num)
const vector<int>& args, PlayerBase& P)
{
int my_num = P.my_num();
InputVecArgList a(args);
complexity += a.n_input_bits();
bool interactive = a.n_interactive_inputs_from_me(my_num) > 0;
for (auto x : a)
{
if (unsigned(x.from) >= unsigned(P.num_players()))
throw runtime_error("invalid player number");
if (x.from == my_num)
{
bigint whole_input = get_long_input<bigint>(x.params,
@@ -237,6 +241,7 @@ void ShareSecret<U>::reveal_inst(Processor<U>& processor,
const vector<int>& args)
{
auto& party = ShareThread<U>::s();
party.check();
assert(args.size() % 3 == 0);
vector<U> shares;
for (size_t i = 0; i < args.size(); i += 3)

View File

@@ -42,6 +42,7 @@ public:
void pre_run(Player& P, typename T::mac_key_type mac_key);
void post_run();
void check();
void and_(Processor<T>& processor, const vector<int>& args, bool repeat);
void xors(Processor<T>& processor, const vector<int>& args);

View File

@@ -75,6 +75,12 @@ void StandaloneShareThread<T>::pre_run()
template<class T>
void ShareThread<T>::post_run()
{
check();
}
template<class T>
void ShareThread<T>::check()
{
protocol->check();
MC->Check(*this->P);

View File

@@ -69,10 +69,6 @@ void ThreadMaster<T>::run()
machine.load_schedule(progname);
if (T::needs_ot)
for (int i = 0; i < machine.nthreads; i++)
machine.ot_setups.push_back({*P, true});
for (int i = 0; i < machine.nthreads; i++)
threads.push_back(new_thread(i));
for (auto thread : threads)
@@ -101,13 +97,15 @@ void ThreadMaster<T>::run()
delete thread;
}
delete P;
exe_stats.print();
stats.print();
cerr << "Time = " << timer.elapsed() << " seconds" << endl;
cerr << "Data sent = " << stats.sent * 1e-6 << " MB" << endl;
machine.print_global_comm(*P, stats);
delete P;
}
} /* namespace GC */

View File

@@ -10,6 +10,20 @@
#include "Protocols/Share.h"
#include "Math/Bit.h"
class gf2n_mac_key : public gf2n_short
{
public:
gf2n_mac_key()
{
}
template<class T>
gf2n_mac_key(const T& other) :
gf2n_short(other)
{
}
};
namespace GC
{
@@ -29,6 +43,9 @@ public:
typedef T mac_type;
typedef T sacri_type;
typedef Share<T> input_check_type;
typedef This prep_type;
typedef This prep_check_type;
typedef This bit_prep_type;
typedef MAC_Check_<This> MAC_Check;
typedef TinierSharePrep<This> LivePrep;
@@ -70,11 +87,6 @@ public:
return new MAC_Check(mac_key);
}
static This new_reg()
{
return {};
}
TinierShare()
{
}

View File

@@ -18,6 +18,16 @@ class TinyMC : public MAC_Check_Base<T>
PointerVector<int> sizes;
public:
static void setup(Player& P)
{
T::part_type::MAC_Check::setup(P);
}
static void teardown()
{
T::part_type::MAC_Check::teardown();
}
TinyMC(typename T::mac_key_type mac_key) :
part_MC(mac_key)
{

View File

@@ -46,11 +46,6 @@ public:
return "tiny share";
}
static This new_reg()
{
return {};
}
TinyShare()
{
}

View File

@@ -1,5 +0,0 @@
192.168.0.1
192.168.0.2
192.168.0.3
192.168.0.4
192.168.0.5

View File

@@ -302,7 +302,7 @@ void OTMachine::run()
// copy base inputs/outputs for each thread
vector<BitVector> base_receiver_input_copy(nthreads);
vector<vector< vector<BitVector> > > base_sender_inputs_copy(nthreads, vector<vector<BitVector> >(nbase, vector<BitVector>(2)));
vector<vector< array<BitVector, 2> > > base_sender_inputs_copy(nthreads, vector<array<BitVector, 2> >(nbase));
vector< vector<BitVector> > base_receiver_outputs_copy(nthreads, vector<BitVector>(nbase));
vector<TwoPartyPlayer*> players(nthreads);

View File

@@ -15,9 +15,9 @@
#include "GC/CcdPrep.hpp"
#include "GC/PersonalPrep.hpp"
//template class GC::ShareParty<GC::TinierSecret<gf2n_short>>;
template class GC::CcdPrep<GC::TinierSecret<gf2n_short>>;
template class Preprocessing<GC::TinierSecret<gf2n_short>>;
template class GC::TinierSharePrep<GC::TinierShare<gf2n_short>>;
template class GC::ShareSecret<GC::TinierSecret<gf2n_short>>;
template class TripleShuffleSacrifice<GC::TinierSecret<gf2n_short>>;
//template class GC::ShareParty<GC::TinierSecret<gf2n_mac_key>>;
template class GC::CcdPrep<GC::TinierSecret<gf2n_mac_key>>;
template class Preprocessing<GC::TinierSecret<gf2n_mac_key>>;
template class GC::TinierSharePrep<GC::TinierShare<gf2n_mac_key>>;
template class GC::ShareSecret<GC::TinierSecret<gf2n_mac_key>>;
template class TripleShuffleSacrifice<GC::TinierSecret<gf2n_mac_key>>;

View File

@@ -32,7 +32,8 @@ void* run_ngenerator_thread(void* ptr)
}
TripleMachine::TripleMachine(int argc, const char** argv) :
nConnections(1), bonding(0)
nConnections(1), player(0), bonding(0),
network_opts(opt, argc, argv, 2, true)
{
opt.add(
"1", // Default.
@@ -66,7 +67,7 @@ TripleMachine::TripleMachine(int argc, const char** argv) :
0, // Required?
0, // Number of args expected.
0, // Delimiter if expecting multiple args.
"Check triples (implies -m).", // Help description.
"Check triples (implies -m; always done with SPDZ2k).", // Help description.
"-c", // Flag token.
"--check" // Flag token.
);
@@ -133,6 +134,7 @@ TripleMachine::TripleMachine(int argc, const char** argv) :
bonding = opt.get("-b")->isSet;
opt.get("-Z")->getInt(z2k);
check |= z2k;
amplify &= not z2k;
z2s = z2k;
if (opt.isSet("-S"))
opt.get("-S")->getInt(z2s);
@@ -156,6 +158,9 @@ template<class T>
GeneratorThread* TripleMachine::new_generator(OTTripleSetup& setup, int i,
typename T::mac_key_type mac_key)
{
if (i == 0)
T::MAC_Check::setup(*player);
if (output and i == 0)
{
prep_data_dir = get_prep_sub_dir<T>(PREP_DIR, nplayers);
@@ -170,16 +175,23 @@ GeneratorThread* TripleMachine::new_generator(OTTripleSetup& setup, int i,
void TripleMachine::run()
{
cout << "my_num: " << my_num << endl;
N[0].init(my_num, 10000, "HOSTS", nplayers);
network_opts.start_networking(N[0], my_num);
nConnections = 1;
if (bonding)
{
N[1].init(my_num, 11000, "HOSTS2", nplayers);
if (network_opts.ip_filename.empty())
{
cerr << "Bonding only works with --ip-file-name" << endl;
exit(1);
}
N[1].init(my_num, network_opts.portnum_base + 1000,
network_opts.ip_filename + "2", nplayers);
nConnections = 2;
}
// do the base OTs
PlainPlayer P(N[0], "base");
OTTripleSetup setup(P, true);
player = &P;
vector<GeneratorThread*> generators(nthreads);
vector<pthread_t> threads(nthreads);

View File

@@ -11,15 +11,33 @@
#include "SPDZ.hpp"
#include "Math/gfp.hpp"
#ifndef N_MAMA_MACS
#define N_MAMA_MACS 3
#endif
template<class T>
using MamaShare_ = MamaShare<T, N_MAMA_MACS>;
template<int L, int N_MACS>
int run(OnlineMachine& machine)
{
return machine.run<MamaShare<gfp_<0, L>, N_MACS>, MamaShare<gf2n, N_MACS>>();
}
int main(int argc, const char** argv)
{
ez::ezOptionParser opt;
DishonestMajorityFieldMachine<MamaShare_, MamaShare_>(argc, argv, opt);
OnlineOptions& online_opts = OnlineOptions::singleton;
online_opts = {opt, argc, argv, MamaShare<gfp_<0, 1>, 3>()};
DishonestMajorityMachine machine(argc, argv, opt, online_opts, 0);
int length = min(online_opts.prime_length() - 1, machine.get_lg2());
int n_macs = DIV_CEIL(online_opts.security_parameter, length);
n_macs = 1 << int(ceil(log2(n_macs)));
if (n_macs > 4)
n_macs = 10;
if (online_opts.prime_limbs() == 1)
{
#define X(N) if (n_macs == N) return run<1, N>(machine);
X(1) X(2) X(4) X(10)
}
if (online_opts.prime_limbs() == 2)
return run<2, 1>(machine);
cerr << "Not compiled for choice of parameters" << endl;
exit(1);
}

View File

@@ -30,7 +30,11 @@ int main(int argc, const char** argv)
opt.get("-SP")->getInt(s);
opt.resetArgs();
RingOptions ring_options(opt, argc, argv);
int k = ring_options.R;
OnlineOptions& online_opts = OnlineOptions::singleton;
online_opts = {opt, argc, argv, Spdz2kShare<64, 64>(), true};
DishonestMajorityMachine machine(argc, argv, opt, online_opts, gf2n());
int k = ring_options.ring_size_from_opts_or_schedule(online_opts.progname);
#ifdef VERBOSE
cerr << "Using SPDZ2k with ring length " << k << " and security parameter "
<< s << endl;
@@ -39,7 +43,7 @@ int main(int argc, const char** argv)
#undef Z
#define Z(K, S) \
if (s == S and k == K) \
return spdz_main<Spdz2kShare<K, S>, Share<gf2n>>(argc, argv, opt);
return machine.run<Spdz2kShare<K, S>, Share<gf2n>>();
Z(64, 64)
Z(64, 48)

View File

@@ -30,6 +30,6 @@ int main(int argc, const char** argv)
{
ez::ezOptionParser opt;
OnlineOptions opts(opt, argc, argv);
gf2n_short::init_minimum(opts.security_parameter);
GC::simple_binary_main<GC::TinierSecret<gf2n_short>>(argc, argv, 1000);
gf2n_mac_key::init_minimum(opts.security_parameter);
GC::simple_binary_main<GC::TinierSecret<gf2n_mac_key>>(argc, argv, 1000);
}

View File

@@ -14,14 +14,13 @@ FHEOBJS = $(patsubst %.cpp,%.o,$(wildcard FHEOffline/*.cpp FHE/*.cpp)) Protocols
GC = $(patsubst %.cpp,%.o,$(wildcard GC/*.cpp)) $(PROCESSOR)
GC_SEMI = GC/SemiPrep.o GC/square64.o
OT = $(patsubst %.cpp,%.o,$(wildcard OT/*.cpp))
OT = $(patsubst %.cpp,%.o,$(wildcard OT/*.cpp)) $(LIBSIMPLEOT)
OT_EXE = ot.x ot-offline.x
COMMONOBJS = $(MATH) $(TOOLS) $(NETWORK) GC/square64.o Processor/OnlineOptions.o Processor/BaseMachine.o Processor/DataPositions.o Processor/ThreadQueues.o Processor/ThreadQueue.o
COMPLETE = $(COMMON) $(PROCESSOR) $(FHEOFFLINE) $(TINYOTOFFLINE) $(GC) $(OT)
YAO = $(patsubst %.cpp,%.o,$(wildcard Yao/*.cpp)) $(OT) BMR/Key.o
BMR = $(patsubst %.cpp,%.o,$(wildcard BMR/*.cpp BMR/network/*.cpp))
MINI_OT = OT/OTTripleSetup.o OT/BaseOT.o $(LIBSIMPLEOT)
VMOBJS = $(PROCESSOR) $(COMMONOBJS) GC/square64.o GC/Instruction.o OT/OTTripleSetup.o OT/BaseOT.o $(LIBSIMPLEOT)
VM = $(MINI_OT) $(SHAREDLIB)
COMMON = $(SHAREDLIB)
@@ -33,13 +32,12 @@ LIB = libSPDZ.a
SHAREDLIB = libSPDZ.so
FHEOFFLINE = libFHE.so
LIBRELEASE = librelease.a
LIBSIMPLEOT_C = deps/SimplestOT_C/ref10/libSimplestOT.a
LIBSIMPLEOT += $(LIBSIMPLEOT_C)
ifeq ($(AVX_OT), 0)
VM += ECDSA/P256Element.o
OT += ECDSA/P256Element.o
MINI_OT += ECDSA/P256Element.o
else
LIBSIMPLEOT = SimpleOT/libsimpleot.a
ifeq ($(AVX_OT), 1)
LIBSIMPLEOT_ASM = deps/SimpleOT/libsimpleot.a
LIBSIMPLEOT += $(LIBSIMPLEOT_ASM)
endif
# used for dependency generation
@@ -106,11 +104,11 @@ else
tldr: mpir linux-machine-setup
endif
tldr:
tldr: libote
$(MAKE) mascot-party.x
ifeq ($(MACHINE), aarch64)
tldr: simde/simde
ifeq ($(ARM), 1)
Tools/intrinsics.h: deps/simde/simde
endif
shamir: shamir-party.x malicious-shamir-party.x atlas-party.x galois-degree.x
@@ -192,7 +190,7 @@ Fake-Offline.x: Utils/Fake-Offline.o $(VM)
$(CXX) -o $@ $(CFLAGS) $^ $(LDLIBS)
%.x: Machines/%.o $(MINI_OT) $(SHAREDLIB)
$(CXX) -o $@ $(CFLAGS) $^ $(LDLIBS)
$(CXX) -o $@ $(CFLAGS) $^ $(LDLIBS) $(SHAREDLIB)
%-ecdsa-party.x: ECDSA/%-ecdsa-party.o ECDSA/P256Element.o $(VM)
$(CXX) -o $@ $(CFLAGS) $^ $(LDLIBS)
@@ -258,47 +256,85 @@ static/no-party.x: Protocols/ShareInterface.o
Test/failure.x: Protocols/MalRepRingOptions.o
ifeq ($(AVX_OT), 1)
$(LIBSIMPLEOT): SimpleOT/Makefile
$(MAKE) -C SimpleOT
$(LIBSIMPLEOT_ASM): deps/SimpleOT/Makefile
$(MAKE) -C deps/SimpleOT
OT/BaseOT.o: SimpleOT/Makefile
OT/BaseOT.o: deps/SimpleOT/Makefile
SimpleOT/Makefile:
git submodule update --init SimpleOT
deps/SimpleOT/Makefile:
git submodule update --init deps/SimpleOT || git clone https://github.com/mkskeller/SimpleOT deps/SimpleOT
endif
$(LIBSIMPLEOT_C): deps/SimplestOT_C/ref10/Makefile
$(MAKE) -C deps/SimplestOT_C/ref10
OT/BaseOT.o: deps/SimplestOT_C/ref10/Makefile
deps/SimplestOT_C/ref10/Makefile:
git submodule update --init deps/SimplestOT_C || git clone https://github.com/mkskeller/SimplestOT_C deps/SimplestOT_C
cd deps/SimplestOT_C/ref10; cmake .
.PHONY: Programs/Circuits
Programs/Circuits:
git submodule update --init Programs/Circuits
.PHONY: mpir-setup mpir-global mpir
mpir-setup:
git submodule update --init mpir || git clone https://github.com/wbhart/mpir
cd mpir; \
git submodule update --init deps/mpir || git clone https://github.com/wbhart/mpir deps/mpir
cd deps/mpir; \
autoreconf -i; \
autoreconf -i
- $(MAKE) -C mpir clean
- $(MAKE) -C deps/mpir clean
mpir-global: mpir-setup
cd mpir; \
cd deps/mpir; \
./configure --enable-cxx;
$(MAKE) -C mpir
sudo $(MAKE) -C mpir install
$(MAKE) -C deps/mpir
sudo $(MAKE) -C deps/mpir install
mpir: mpir-setup
cd mpir; \
cd deps/mpir; \
./configure --enable-cxx --prefix=$(CURDIR)/local
$(MAKE) -C mpir install
$(MAKE) -C deps/mpir install
-echo MY_CFLAGS += -I./local/include >> CONFIG.mine
-echo MY_LDLIBS += -Wl,-rpath -Wl,$(CURDIR)/local/lib -L$(CURDIR)/local/lib >> CONFIG.mine
deps/libOTe/libOTe:
git submodule update --init --recursive deps/libOTe
cd deps/libOTe; \
python3 build.py --setup --boost --install=$(CURDIR)/local
-echo MY_CFLAGS += -I./local/include >> CONFIG.mine
-echo MY_LDLIBS += -Wl,-rpath -Wl,$(CURDIR)/local/lib -L$(CURDIR)/local/lib >> CONFIG.mine
OTE_OPTS = -DENABLE_SOFTSPOKEN_OT=ON -DCMAKE_CXX_COMPILER=$(CXX)
ifeq ($(ARM), 1)
libote: deps/libOTe/libOTe
cd deps/libOTe; \
PATH=$(CURDIR)/local/bin:$(PATH) python3 build.py --install=$(CURDIR)/local -- -DBUILD_SHARED_LIBS=0 -DENABLE_AVX=OFF -DENABLE_SSE=OFF $(OTE_OPTS)
else
libote: deps/libOTe/libOTe
cd deps/libOTe; \
PATH=$(CURDIR)/local/bin:$(PATH) python3 build.py --install=$(CURDIR)/local -- -DBUILD_SHARED_LIBS=0 $(OTE_OPTS)
endif
libote-shared: deps/libOTe/libOTe
cd deps/libOTe; \
python3 build.py --install=$(CURDIR)/local -- -DBUILD_SHARED_LIBS=1 $(OTE_OPTS)
cmake:
wget https://github.com/Kitware/CMake/releases/download/v3.24.1/cmake-3.24.1.tar.gz
tar xzvf cmake-3.24.1.tar.gz
cd cmake-3.24.1; \
./bootstrap --parallel=8 --prefix=../local && make && make install
mac-setup: mac-machine-setup
brew install openssl boost libsodium mpir yasm ntl
brew install openssl boost libsodium mpir yasm ntl cmake
-echo MY_CFLAGS += -I/usr/local/opt/openssl/include -I`brew --prefix`/opt/openssl/include -I`brew --prefix`/include >> CONFIG.mine
-echo MY_LDLIBS += -L/usr/local/opt/openssl/lib -L`brew --prefix`/lib -L`brew --prefix`/opt/openssl/lib >> CONFIG.mine
# -echo USE_NTL = 1 >> CONFIG.mine
ifeq ($(MACHINE), aarch64)
ifeq ($(ARM), 1)
mac-machine-setup:
-echo ARCH = >> CONFIG.mine
linux-machine-setup:
@@ -308,8 +344,8 @@ mac-machine-setup:
linux-machine-setup:
endif
simde/simde:
git submodule update --init simde || git clone https://github.com/simd-everywhere/simde
deps/simde/simde:
git submodule update --init deps/simde || git clone https://github.com/simd-everywhere/simde deps/simde
clean:
-rm -f */*.o *.o */*.d *.d *.x core.* *.a gmon.out */*/*.o static/*.x *.so

View File

@@ -30,6 +30,8 @@ public:
Square& rsub(const Square& other);
Square& sub(const void* other);
void bit_sub(const BitVector& bits, int start);
void randomize(int row, PRNG& G) { rows[row].randomize(G); }
void conditional_add(BitVector& conditions, Square& other,
int offset);

View File

@@ -31,6 +31,15 @@ Square<U>& Square<U>::sub(const void* other)
return *this;
}
template<class U>
void Square<U>::bit_sub(const BitVector& bits, int start)
{
for (int i = 0; i < U::length(); i++)
{
rows[i] -= bits.get_bit(start + i);
}
}
template<class U>
void Square<U>::conditional_add(BitVector& conditions,
Square<U>& other, int offset)

View File

@@ -6,6 +6,8 @@
void Zp_Data::init(const bigint& p,bool mont)
{
lock.lock();
#ifdef VERBOSE
if (pr != 0)
{
@@ -57,6 +59,8 @@ void Zp_Data::init(const bigint& p,bool mont)
}
inline_mpn_zero(prA,MAX_MOD_SZ+1);
mpn_copyi(prA,pr.get_mpz_t()->_mp_d,t);
lock.unlock();
}

View File

@@ -14,6 +14,7 @@
#include "Math/mpn_fixed.h"
#include "Tools/random.h"
#include "Tools/intrinsics.h"
#include "Tools/Lock.h"
#include <iostream>
using namespace std;
@@ -36,6 +37,7 @@ class Zp_Data
mp_limb_t prA[MAX_MOD_SZ+1];
int t; // More Montgomery data
mp_limb_t overhang;
Lock lock;
template <int T>
void Mont_Mult_(mp_limb_t* z,const mp_limb_t* x,const mp_limb_t* y) const;

View File

@@ -115,7 +115,7 @@ public:
{ return mpz_sizeinbase(get_mpz_t(), 2); }
void generateUniform(PRNG& G, int n_bits, bool positive = false)
{ G.get_bigint(*this, n_bits, positive); }
{ G.get(*this, n_bits, positive); }
void pack(octetStream& os) const { os.store(*this); }
void unpack(octetStream& os) { os.get(*this); };

View File

@@ -51,7 +51,7 @@ void ssl_error(string side, string other, string me)
}
CryptoPlayer::CryptoPlayer(const Names& Nms, const string& id_base) :
MultiPlayer<ssl_socket*>(Nms),
MultiPlayer<ssl_socket*>(Nms, id_base),
ctx("P" + to_string(my_num()))
{
sockets.resize(num_players());
@@ -137,6 +137,34 @@ void CryptoPlayer::receive_player_no_stats(int other, octetStream& o) const
receivers[other]->wait(o);
}
size_t CryptoPlayer::send_no_stats(int player, const PlayerBuffer& buffer,
bool block) const
{
assert(player != my_num());
auto socket = senders.at(player)->get_socket();
if (block)
{
send(socket, buffer.data, buffer.size);
return buffer.size;
}
else
return send_non_blocking(socket, buffer.data, buffer.size);
}
size_t CryptoPlayer::recv_no_stats(int player, const PlayerBuffer& buffer,
bool block) const
{
assert(player != my_num());
auto socket = receivers.at(player)->get_socket();
if (block)
{
receive(socket, buffer.data, buffer.size);
return buffer.size;
}
else
return receive_non_blocking(socket, buffer.data, buffer.size);
}
void CryptoPlayer::exchange_no_stats(int other, const octetStream& to_send,
octetStream& to_receive) const
{

View File

@@ -46,6 +46,11 @@ public:
void send_to_no_stats(int other, const octetStream& o) const;
void receive_player_no_stats(int other, octetStream& o) const;
size_t send_no_stats(int player, const PlayerBuffer& buffer,
bool block) const;
size_t recv_no_stats(int player, const PlayerBuffer& buffer,
bool block) const;
void exchange_no_stats(int other, const octetStream& to_send,
octetStream& to_receive) const;

View File

@@ -216,15 +216,15 @@ Player::Player(const Names& Nms) :
template<class T>
MultiPlayer<T>::MultiPlayer(const Names& Nms) :
Player(Nms), send_to_self_socket(0)
MultiPlayer<T>::MultiPlayer(const Names& Nms, const string& id) :
Player(Nms), id(id), send_to_self_socket(0)
{
sockets.resize(Nms.num_players());
}
PlainPlayer::PlainPlayer(const Names& Nms, const string& id) :
MultiPlayer<int>(Nms)
MultiPlayer<int>(Nms, id)
{
if (Nms.num_players() > 1)
setup_sockets(Nms.names, Nms.ports, id, *Nms.server);
@@ -399,6 +399,30 @@ void Player::receive_player(int i, FlexBuffer& buffer) const
buffer = os;
}
size_t PlainPlayer::send_no_stats(int player,
const PlayerBuffer& buffer, bool block) const
{
if (block)
{
send(socket(player), buffer.data, buffer.size);
return buffer.size;
}
else
return send_non_blocking(socket(player), buffer.data, buffer.size);
}
size_t PlainPlayer::recv_no_stats(int player,
const PlayerBuffer& buffer, bool block) const
{
if (block)
{
receive(socket(player), buffer.data, buffer.size);
return buffer.size;
}
else
return receive_non_blocking(socket(player), buffer.data, buffer.size);
}
void Player::send_relative(const vector<octetStream>& os) const
{
@@ -647,10 +671,8 @@ void ThreadPlayer::send_all(const octetStream& o) const
RealTwoPartyPlayer::RealTwoPartyPlayer(const Names& Nms, int other_player, const string& id) :
TwoPartyPlayer(Nms.my_num()), other_player(other_player)
VirtualTwoPartyPlayer(*(P = new PlainPlayer(Nms, id + "2")), other_player)
{
is_server = Nms.my_num() > other_player;
setup_sockets(other_player, Nms, Nms.ports[other_player], id);
}
RealTwoPartyPlayer::RealTwoPartyPlayer(const Names& Nms, int other_player,
@@ -660,40 +682,7 @@ RealTwoPartyPlayer::RealTwoPartyPlayer(const Names& Nms, int other_player,
RealTwoPartyPlayer::~RealTwoPartyPlayer()
{
close_client_socket(socket);
}
void RealTwoPartyPlayer::setup_sockets(int other_player, const Names &nms, int portNum, string id)
{
id += "2";
const char *hostname = nms.names[other_player].c_str();
ServerSocket *server = nms.server;
if (is_server) {
#ifdef DEBUG_NETWORKING
fprintf(stderr, "Setting up server with id %s\n", id.c_str());
#endif
socket = server->get_connection_socket(id);
}
else {
#ifdef DEBUG_NETWORKING
fprintf(stderr, "Setting up client to %s:%d with id %s\n", hostname,
portNum, id.c_str());
#endif
set_up_client_socket(socket, hostname, portNum);
octetStream(id).Send(socket);
}
}
int RealTwoPartyPlayer::other_player_num() const
{
return other_player;
}
void RealTwoPartyPlayer::send(octetStream& o) const
{
TimeScope ts(comm_stats["Sending one-to-one"].add(o));
o.Send(socket);
sent += o.get_length();
delete P;
}
void VirtualTwoPartyPlayer::send(octetStream& o) const
@@ -703,14 +692,6 @@ void VirtualTwoPartyPlayer::send(octetStream& o) const
comm_stats.sent += o.get_length();
}
void RealTwoPartyPlayer::receive(octetStream& o) const
{
TimeScope ts(timer);
o.reset_write_head();
o.Receive(socket);
comm_stats["Receiving one-to-one"].add(o, ts);
}
void VirtualTwoPartyPlayer::receive(octetStream& o) const
{
TimeScope ts(timer);
@@ -718,29 +699,6 @@ void VirtualTwoPartyPlayer::receive(octetStream& o) const
comm_stats["Receiving one-to-one"].add(o, ts);
}
void RealTwoPartyPlayer::send_receive_player(vector<octetStream>& o) const
{
{
if (is_server)
{
send(o[0]);
receive(o[1]);
}
else
{
receive(o[1]);
send(o[0]);
}
}
}
void RealTwoPartyPlayer::exchange(octetStream& o) const
{
TimeScope ts(comm_stats["Exchanging one-to-one"].add(o));
sent += o.get_length();
o.exchange(socket, socket);
}
void VirtualTwoPartyPlayer::send_receive_player(vector<octetStream>& o) const
{
TimeScope ts(comm_stats["Exchanging one-to-one"].add(o[0]));
@@ -754,6 +712,25 @@ VirtualTwoPartyPlayer::VirtualTwoPartyPlayer(Player& P, int other_player) :
{
}
size_t VirtualTwoPartyPlayer::send(const PlayerBuffer& buffer, bool block) const
{
auto sent = P.send_no_stats(other_player, buffer, block);
lock.lock();
comm_stats["Sending one-to-one"].add(sent);
comm_stats.sent += sent;
lock.unlock();
return sent;
}
size_t VirtualTwoPartyPlayer::recv(const PlayerBuffer& buffer, bool block) const
{
auto received = P.recv_no_stats(other_player, buffer, block);
lock.lock();
comm_stats["Receiving one-to-one"].add(received);
lock.unlock();
return received;
}
void OffsetPlayer::send_receive_player(vector<octetStream>& o) const
{
P.exchange(P.get_player(offset), o[0], o[1]);

View File

@@ -22,6 +22,8 @@ using namespace std;
#include "Networking/Receiver.h"
#include "Networking/Sender.h"
#include "Tools/ezOptionParser.h"
#include "Networking/PlayerBuffer.h"
#include "Tools/Lock.h"
template<class T> class MultiPlayer;
class Server;
@@ -225,6 +227,8 @@ public:
Player(const Names& Nms);
virtual ~Player();
virtual string get_id() const { throw not_implemented(); }
/**
* Get number of players
*/
@@ -266,6 +270,11 @@ public:
virtual void receive_player_no_stats(int i,octetStream& o) const = 0;
virtual void receive_player(int i,FlexBuffer& buffer) const;
virtual size_t send_no_stats(int, const PlayerBuffer&, bool) const
{ throw not_implemented(); }
virtual size_t recv_no_stats(int, const PlayerBuffer&, bool) const
{ throw not_implemented(); }
/**
* Send to all other players by offset.
* ``o[0]`` gets sent to the next player etc.
@@ -389,6 +398,8 @@ public:
template<class T>
class MultiPlayer : public Player
{
string id;
protected:
vector<T> sockets;
T send_to_self_socket;
@@ -399,10 +410,12 @@ protected:
friend class CryptoPlayer;
public:
MultiPlayer(const Names& Nms);
MultiPlayer(const Names& Nms, const string& id);
virtual ~MultiPlayer();
string get_id() const { return id; }
// Send/Receive data to/from player i
void send_long(int i, long a) const;
long receive_long(int i) const;
@@ -448,6 +461,9 @@ public:
// legacy interface
PlainPlayer(const Names& Nms, int id_base = 0);
~PlainPlayer();
size_t send_no_stats(int player, const PlayerBuffer& buffer, bool block) const;
size_t recv_no_stats(int player, const PlayerBuffer& buffer, bool block) const;
};
@@ -481,39 +497,11 @@ public:
virtual void receive(octetStream& o) const = 0;
virtual void send_receive_player(vector<octetStream>& o) const = 0;
void Broadcast_Receive(vector<octetStream>& o) const;
};
class RealTwoPartyPlayer : public TwoPartyPlayer
{
private:
// setup sockets for comm. with only one other player
void setup_sockets(int other_player, const Names &nms, int portNum, string id);
int socket;
bool is_server;
int other_player;
public:
RealTwoPartyPlayer(const Names& Nms, int other_player, const string& id);
// legacy
RealTwoPartyPlayer(const Names& Nms, int other_player, int id_base = 0);
~RealTwoPartyPlayer();
void send(octetStream& o) const;
void receive(octetStream& o) const;
int other_player_num() const;
int my_num() const { return is_server; }
int num_players() const { return 2; }
/* Send and receive to/from the other player
* - o[0] contains my data, received data put in o[1]
*/
void send_receive_player(vector<octetStream>& o) const;
void exchange(octetStream& o) const;
void exchange(int other, octetStream& o) const { (void)other; exchange(o); }
void pass_around(octetStream& o, int offset = 1) const { (void)offset; exchange(o); }
virtual size_t send(const PlayerBuffer&, bool) const
{ throw not_implemented(); }
virtual size_t recv(const PlayerBuffer&, bool) const
{ throw not_implemented(); }
};
// for different threads, separate statistics
@@ -523,6 +511,8 @@ class VirtualTwoPartyPlayer : public TwoPartyPlayer
int other_player;
NamedCommStats& comm_stats;
mutable Lock lock;
public:
VirtualTwoPartyPlayer(Player& P, int other_player);
@@ -536,6 +526,20 @@ public:
void send_receive_player(vector<octetStream>& o) const;
void pass_around(octetStream& o, int _ = 1) const { (void)_, (void) o; throw not_implemented(); }
size_t send(const PlayerBuffer& buffer, bool block) const;
size_t recv(const PlayerBuffer& buffer, bool block) const;
};
class RealTwoPartyPlayer : public VirtualTwoPartyPlayer
{
PlainPlayer* P;
public:
RealTwoPartyPlayer(const Names& Nms, int other_player, const string& id);
// legacy
RealTwoPartyPlayer(const Names& Nms, int other_player, int id_base = 0);
~RealTwoPartyPlayer();
};
// for the same thread

23
Networking/PlayerBuffer.h Normal file
View File

@@ -0,0 +1,23 @@
/*
* PlayerBuffer.h
*
*/
#ifndef NETWORKING_PLAYERBUFFER_H_
#define NETWORKING_PLAYERBUFFER_H_
#include "Tools/int.h"
class PlayerBuffer
{
public:
octet* data;
size_t size;
PlayerBuffer(octet* data, size_t size) :
data(data), size(size)
{
}
};
#endif /* NETWORKING_PLAYERBUFFER_H_ */

169
Networking/PlayerCtSocket.h Normal file
View File

@@ -0,0 +1,169 @@
/*
* PlayerSocket.h
*
*/
#ifndef NETWORKING_PLAYERCTSOCKET_H_
#define NETWORKING_PLAYERCTSOCKET_H_
#include "Player.h"
#include "Tools/Lock.h"
#include <cryptoTools/Network/SocketAdapter.h>
class PlayerCtSocket : public osuCrypto::SocketInterface
{
class Pack
{
public:
deque<PlayerBuffer> buffers;
osuCrypto::io_completion_handle fn;
size_t total;
Pack() :
total(0)
{
}
Pack(osuCrypto::io_completion_handle& fn,
gsl::span<boost::asio::mutable_buffer> buffers) :
fn(fn),
total(0)
{
for (auto& buffer : buffers)
{
auto data = boost::asio::buffer_cast<osuCrypto::u8*>(buffer);
auto size = boost::asio::buffer_size(buffer);
this->buffers.push_back({data, size});
}
}
};
TwoPartyPlayer& P;
WaitQueue<Pack> send_packs, receive_packs;
pthread_t send_thread, receive_thread;
static void* run_send(void* socket)
{
((PlayerCtSocket*) socket)->send();
return 0;
}
static void* run_receive(void* socket)
{
((PlayerCtSocket*) socket)->receive();
return 0;
}
void debug(const char* msg)
{
(void) msg;
#ifdef DEBUG_CT
printf("%p %lx %s\n", this, pthread_self(), msg);
#endif
}
void debug(const char* msg, size_t n)
{
(void) msg, (void) n;
#ifdef DEBUG_CT
printf("%p %lx %s %lu\n", this, pthread_self(), msg, n);
#endif
}
public:
PlayerCtSocket(TwoPartyPlayer& P) :
P(P)
{
pthread_create(&send_thread, 0, run_send, this);
pthread_create(&receive_thread, 0, run_receive, this);
}
~PlayerCtSocket()
{
send_packs.stop();
receive_packs.stop();
pthread_join(send_thread, 0);
pthread_join(receive_thread, 0);
}
void async_send(gsl::span<boost::asio::mutable_buffer> buffers,
osuCrypto::io_completion_handle&& fn) override
{
debug("async send");
send_packs.push(Pack(fn, buffers));
}
void async_recv(gsl::span<boost::asio::mutable_buffer> buffers,
osuCrypto::io_completion_handle&& fn) override
{
debug("async recv");
receive_packs.push(Pack(fn, buffers));
}
void send()
{
Pack pack;
while (send_packs.pop(pack))
{
#ifdef DEBUG_CT
debug("got to send", send_packs.size());
#endif
while (not pack.buffers.empty())
{
auto& buffer = pack.buffers.front();
auto sent = P.send(buffer, true);
buffer.data += sent;
buffer.size -= sent;
pack.total += sent;
#ifdef DEBUG_CT
printf("%p %lx sent %lu total %lu left %lu\n", this, pthread_self(), sent, pack.total, buffer.size);
if (sent == 4)
debug("content", *(word*)(buffer.data - sent));
#endif
if (buffer.size == 0)
pack.buffers.pop_front();
}
{
boost::system::error_code ec;
auto total = pack.total;
auto fn = pack.fn;
debug("send callback", total);
fn(ec, total);
}
}
}
void receive()
{
Pack pack;
while (receive_packs.pop(pack))
{
debug("got to receive");
while (not pack.buffers.empty())
{
auto& buffer = pack.buffers.front();
auto sent = P.recv(buffer, true);
buffer.data += sent;
buffer.size -= sent;
pack.total += sent;
#ifdef DEBUG_CT
printf("%p %lx received %lu total %lu left %lu\n", this, pthread_self(), sent, pack.total, buffer.size);
if (sent == 4)
debug("content", *(word*)(buffer.data - sent));
#endif
if (buffer.size == 0)
pack.buffers.pop_front();
}
{
boost::system::error_code ec;
auto total = pack.total;
auto fn = pack.fn;
debug("recv callback", total);
fn(ec, total);
}
}
}
};
#endif /* NETWORKING_PLAYERCTSOCKET_H_ */

View File

@@ -35,6 +35,11 @@ public:
Receiver(T socket);
~Receiver();
T get_socket()
{
return socket;
}
void request(octetStream& os);
void wait(octetStream& os);
};

View File

@@ -35,6 +35,11 @@ public:
Sender(T socket);
~Sender();
T get_socket()
{
return socket;
}
void request(const octetStream& os);
void wait(const octetStream& os);
};

View File

@@ -8,18 +8,14 @@
#include <fstream>
#include <pthread.h>
#ifndef NO_AVX_OT
extern "C" {
#ifndef NO_AVX_OT
#include "SimpleOT/ot_sender.h"
#include "SimpleOT/ot_receiver.h"
#endif
#include "SimplestOT_C/ref10/ot_sender.h"
#include "SimplestOT_C/ref10/ot_receiver.h"
}
#endif
#include "ECDSA/P256Element.h"
#ifdef USE_RISTRETTO
#include "ECDSA/CurveElement.h"
#endif
using namespace std;
@@ -76,86 +72,60 @@ void send_if_ot_receiver(TwoPartyPlayer* P, vector<octetStream>& os, OT_ROLE rol
}
}
// type-dependent redirection
void sender_genS(ref10_SENDER* s, unsigned char* S_pack)
{
ref10_sender_genS(s, S_pack);
}
void sender_keygen(ref10_SENDER* s, unsigned char* Rs_pack,
unsigned char (*keys)[4][HASHBYTES])
{
ref10_sender_keygen(s, Rs_pack, keys);
}
void receiver_maketable(ref10_RECEIVER* r)
{
ref10_receiver_maketable(r);
}
void receiver_procS(ref10_RECEIVER* r)
{
ref10_receiver_procS(r);
}
void receiver_rsgen(ref10_RECEIVER* r, unsigned char* Rs_pack,
unsigned char* cs)
{
ref10_receiver_rsgen(r, Rs_pack, cs);
}
void receiver_keygen(ref10_RECEIVER* r, unsigned char (*keys)[HASHBYTES])
{
ref10_receiver_keygen(r, keys);
}
void BaseOT::exec_base(bool new_receiver_inputs)
{
Bundle<octetStream> bundle(*P);
#ifdef NO_AVX_OT
bundle.mine = string("OT without AVX");
#else
bundle.mine = string("OT with AVX");
#ifndef NO_AVX_OT
if (cpu_has_avx(true))
exec_base<SIMPLEOT_SENDER, SIMPLEOT_RECEIVER>(new_receiver_inputs);
else
#endif
try
{
bundle.compare(*P);
}
catch (mismatch_among_parties&)
{
cerr << "Parties compiled with different base OT algorithms" << endl;
cerr << "Set \"AVX_OT\" to the same value on all parties" << endl;
exit(1);
}
#ifdef NO_AVX_OT
#ifdef USE_RISTRETTO
typedef CurveElement Element;
#else
typedef P256Element Element;
#endif
Element::init();
vector<Element::Scalar> as, bs;
vector<Element> As;
SeededPRNG G;
vector<octetStream> os(2);
if (ot_role & SENDER)
for (int i = 0; i < nOT; i++)
{
as.push_back(G.get<Element::Scalar>());
As.push_back(as.back());
As.back().pack(os[0]);
}
send_if_ot_sender(P, os, ot_role);
os[0].reset_write_head();
if (ot_role & RECEIVER)
for (int i = 0; i < nOT; i++)
{
if (new_receiver_inputs)
receiver_inputs[i] = G.get_bit();
auto b = G.get<Element::Scalar>();
Element B = b;
auto A = os[1].get<Element>();
if (receiver_inputs[i])
B += A;
B.pack(os[0]);
receiver_outputs[i] = (A * b).hash(AES_BLK_SIZE);
}
send_if_ot_receiver(P, os, ot_role);
if (ot_role & SENDER)
for (int i = 0; i < nOT; i++)
{
auto B = os[1].get<Element>();
sender_inputs.at(i).at(0) = (B * as[i]).hash(AES_BLK_SIZE);
sender_inputs.at(i).at(1) = ((B - As[i]) * as[i]).hash(AES_BLK_SIZE);
}
#else
if (not cpu_has_avx(true))
throw runtime_error("SimpleOT needs AVX support");
exec_base<ref10_SENDER, ref10_RECEIVER>(new_receiver_inputs);
}
template<class T, class U>
void BaseOT::exec_base(bool new_receiver_inputs)
{
int i, j, k;
size_t len;
PRNG G;
G.ReSeed();
vector<octetStream> os(2);
SIMPLEOT_SENDER sender;
SIMPLEOT_RECEIVER receiver;
T sender;
U receiver;
unsigned char S_pack[ PACKBYTES ];
unsigned char Rs_pack[ 2 ][ 4 * PACKBYTES ];
@@ -188,7 +158,7 @@ void BaseOT::exec_base(bool new_receiver_inputs)
{
if (ot_role & RECEIVER)
{
for (j = 0; j < 4; j++)
for (j = 0; j < 4 and (i + j) < nOT; j++)
{
if (new_receiver_inputs)
receiver_inputs[i + j] = G.get_uchar()&1;
@@ -199,7 +169,7 @@ void BaseOT::exec_base(bool new_receiver_inputs)
receiver_keygen(&receiver, receiver_keys);
// Copy keys to receiver_outputs
for (j = 0; j < 4; j++)
for (j = 0; j < 4 and (i + j) < nOT; j++)
{
for (k = 0; k < AES_BLK_SIZE; k++)
{
@@ -236,7 +206,7 @@ void BaseOT::exec_base(bool new_receiver_inputs)
sender_keygen(&sender, Rs_pack[1], sender_keys);
// Copy 128 bits of keys to sender_inputs
for (j = 0; j < 4; j++)
for (j = 0; j < 4 and (i + j) < nOT; j++)
{
for (k = 0; k < AES_BLK_SIZE; k++)
{
@@ -261,7 +231,6 @@ void BaseOT::exec_base(bool new_receiver_inputs)
printf("\n");
#endif
}
#endif
for (int i = 0; i < nOT; i++)
{

View File

@@ -33,7 +33,7 @@ class BaseOT
public:
BitVector receiver_inputs;
vector< vector<BitVector> > sender_inputs;
vector< array<BitVector, 2> > sender_inputs;
vector<BitVector> receiver_outputs;
TwoPartyPlayer* P;
int nOT, ot_length;
@@ -43,9 +43,9 @@ public:
: P(player), nOT(nOT), ot_length(ot_length), ot_role(role)
{
receiver_inputs.resize(nOT);
sender_inputs.resize(nOT, vector<BitVector>(2));
sender_inputs.resize(nOT);
receiver_outputs.resize(nOT);
G_sender.resize(nOT, vector<PRNG>(2));
G_sender.resize(nOT);
G_receiver.resize(nOT);
for (int i = 0; i < nOT; i++)
@@ -88,11 +88,14 @@ public:
void check();
protected:
vector< vector<PRNG> > G_sender;
vector< array<PRNG, 2> > G_sender;
vector<PRNG> G_receiver;
bool is_sender() { return (bool) (ot_role & SENDER); }
bool is_receiver() { return (bool) (ot_role & RECEIVER); }
template<class T, class U>
void exec_base(bool new_receiver_inputs=true);
};
class FakeOT : public BaseOT

View File

@@ -7,6 +7,7 @@
#define OT_BITMATRIX_H_
#include "Tools/intrinsics.h"
#include "Tools/Exceptions.h"
#include <vector>
#include <iostream>
@@ -56,6 +57,8 @@ union square128 {
square128& sub(const __m128i* other);
square128& sub(const void* other) { return sub((__m128i*)other); }
void bit_sub(const BitVector&, int) { throw not_implemented(); }
void randomize(PRNG& G);
void randomize(int row, PRNG& G);
void conditional_add(BitVector& conditions, square128& other, int offset);

View File

@@ -122,8 +122,11 @@ Slice<U>& Slice<U>::sub(BitVector& other, int repeat)
throw invalid_length(to_string(U::PartType::n_columns()));
for (size_t i = start; i < end; i++)
{
bm.squares[i].sub(other.get_ptr_to_byte(i / repeat,
U::PartType::n_row_bytes()));
if (repeat > 0)
bm.squares[i].sub(other.get_ptr_to_byte(i / repeat,
U::PartType::n_row_bytes()));
else
bm.squares[i].bit_sub(other, i * U::PartType::n_rows());
}
return *this;
}

View File

@@ -56,6 +56,11 @@ public:
return *this;
}
void bit_sub(const BitVector&, int)
{
throw not_implemented();
}
void randomize(int row, PRNG& G)
{
squares[row / T::Square::n_rows()].randomize(

View File

@@ -80,7 +80,7 @@ public:
vector<typename T::Multiplier*> ot_multipliers;
//vector<OTMachine*> machines;
BitVector baseReceiverInput; // same for every set of OTs
vector< vector< vector<BitVector> > > baseSenderInputs;
vector< vector< array<BitVector, 2> > > baseSenderInputs;
vector< vector<BitVector> > baseReceiverOutputs;
vector<BitVector> valueBits;
BitVector b_padded_bits;
@@ -98,6 +98,7 @@ public:
vector<PlainTriple<open_type, N_AMPLIFY>> preampTriples;
vector<array<open_type, 3>> plainTriples;
vector<open_type> plainBits;
typename T::MAC_Check* MC;
@@ -112,6 +113,8 @@ public:
void generatePlainTriples();
void plainTripleRound(int k = 0);
void generatePlainBits();
void run_multipliers(MultJob job);
mac_key_type get_mac_key() const { return mac_key; }

View File

@@ -24,6 +24,7 @@
template<class T>
void* run_ot_thread(void* ptr)
{
bigint::init_thread();
((OTMultiplierBase*)ptr)->multiply();
return NULL;
}
@@ -379,8 +380,7 @@ void Spdz2kTripleGenerator<T>::generateTriples()
b_padded_bits.resize(8 * Z2<K + 2 * S>::N_BYTES * (nTriplesPerLoop + 1));
vector< PlainTriple_<Z2<K + 2 * S>, Z2<K + S>, 2> > amplifiedTriples(nTriplesPerLoop);
uncheckedTriples.resize(nTriplesPerLoop);
MAC_Check_Z2k<Z2<K + 2 * S>, Z2<S>, Z2<K + S>, Share<Z2<K + 2 * S>> > MC(
this->get_mac_key());
typename T::prep_check_type::MAC_Check MC(this->get_mac_key());
this->start_progress();
@@ -481,6 +481,35 @@ void OTTripleGenerator<U>::generatePlainTriples()
plainTripleRound(i);
}
template<class T>
void OTTripleGenerator<T>::generatePlainBits()
{
assert(ot_multipliers.size() == 1);
machine.set_passive();
machine.output = false;
int n = multiple_minimum(nPreampTriplesPerLoop, T::open_type::size_in_bits());
valueBits.resize(1);
valueBits[0].resize(n);
valueBits[0].randomize(share_prg);
signal_multipliers(DATA_BIT);
wait_for_multipliers();
plainBits.clear();
for (int j = 0; j < n; j++)
{
if (j % T::open_type::size_in_bits() < T::open_type::length())
{
plainBits.push_back(valueBits[0].get_bit(j));
plainBits.back() += ot_multipliers[0]->c_output[j] * 2;
}
}
}
template<class U>
void OTTripleGenerator<U>::plainTripleRound(int k)
{
@@ -730,6 +759,7 @@ void Spdz2kTripleGenerator<W>::sacrificeZ2k(U& MC, PRNG& G)
{
typedef sacri_type T;
typedef open_type V;
typedef typename W::prep_check_type prep_check_type;
auto& machine = this->machine;
auto& nTriplesPerLoop = this->nTriplesPerLoop;
@@ -737,7 +767,7 @@ void Spdz2kTripleGenerator<W>::sacrificeZ2k(U& MC, PRNG& G)
auto& outputFile = this->outputFile;
auto& uncheckedTriples = this->uncheckedTriples;
vector< Share<T> > maskedAs(nTriplesPerLoop);
vector<prep_check_type> maskedAs(nTriplesPerLoop);
vector<TripleToSacrifice<Share<T>> > maskedTriples(nTriplesPerLoop);
for (int j = 0; j < nTriplesPerLoop; j++)
{
@@ -753,7 +783,7 @@ void Spdz2kTripleGenerator<W>::sacrificeZ2k(U& MC, PRNG& G)
MC.POpen_Begin(openedAs, maskedAs, globalPlayer);
MC.POpen_End(openedAs, maskedAs, globalPlayer);
vector<Share<T>> sigmas;
vector<prep_check_type> sigmas;
for (int j = 0; j < nTriplesPerLoop; j++) {
// compute t * [c] - [chat] - [b] * p
sigmas.push_back(maskedTriples[j].computeCheckShare(V(openedAs[j])));

View File

@@ -6,6 +6,7 @@
#include "Tools/aes.h"
#include "Tools/MMO.h"
#include "Tools/intrinsics.h"
#include "Tools/benchmarking.h"
OTExtension::OTExtension(const BaseOT& baseOT, TwoPartyPlayer* player,
@@ -26,10 +27,18 @@ int eq_m128i(__m128i a, __m128i b)
}
bool OTExtensionWithMatrix::warned = false;
void OTExtensionWithMatrix::check_correlation(int nOTs,
const BitVector& receiverInput)
{
if (not warned)
{
insecure("OT extension (security of KOS15 is unclear, "
"see https://eprint.iacr.org/2022/192.)");
warned = true;
}
//cout << "\tStarting correlation check\n" << flush;
#ifdef OTEXT_TIMER
timeval startv, endv;

View File

@@ -38,7 +38,7 @@ public:
}
void init(const BitVector& baseReceiverInput,
const vector< vector<BitVector> >& baseSenderInput,
const vector< array<BitVector, 2> >& baseSenderInput,
const vector<BitVector>& baseReceiverOutput)
{
nbaseOTs = baseReceiverInput.size();

View File

@@ -4,6 +4,13 @@
*/
#include "OTExtensionWithMatrix.h"
#include "Tools/Bundle.h"
#ifndef USE_KOS
#include "Networking/PlayerCtSocket.h"
osuCrypto::IOService OTExtensionWithMatrix::ios;
#endif
#include "OTCorrelator.hpp"
@@ -23,6 +30,43 @@ OTExtensionWithMatrix::OTExtensionWithMatrix(BaseOT& baseOT, TwoPartyPlayer* pla
{
G.ReSeed();
nsubloops = 1;
agreed = false;
#ifndef USE_KOS
channel = 0;
#endif
}
OTExtensionWithMatrix::~OTExtensionWithMatrix()
{
#ifndef USE_KOS
if (channel)
delete channel;
#endif
}
void OTExtensionWithMatrix::protocol_agreement()
{
if (agreed)
return;
Bundle<octetStream> bundle(*player);
#ifdef USE_KOS
bundle.mine = string("KOS15");
#else
bundle.mine = string("SoftSpokenOT");
#endif
player->unchecked_broadcast(bundle);
try
{
bundle.compare(*player);
}
catch (mismatch_among_parties&)
{
cerr << "Parties compiled with different OT extensions" << endl;
cerr << "Set \"USE_KOS\" to the same value on all parties" << endl;
exit(1);
}
}
void OTExtensionWithMatrix::transfer(int nOTs,
@@ -57,12 +101,103 @@ void OTExtensionWithMatrix::transfer(int nOTs,
#endif
}
void OTExtensionWithMatrix::extend(int nOTs_requested, BitVector& newReceiverInput)
void OTExtensionWithMatrix::extend(int nOTs_requested, const BitVector& newReceiverInput)
{
protocol_agreement();
#ifdef USE_KOS
extend_correlated(nOTs_requested, newReceiverInput);
hash_outputs(nOTs_requested);
#else
resize(nOTs_requested);
if (not channel)
channel = new osuCrypto::Channel(ios, new PlayerCtSocket(*player));
if (player->my_num())
{
soft_sender(nOTs_requested);
soft_receiver(nOTs_requested, newReceiverInput);
}
else
{
soft_receiver(nOTs_requested, newReceiverInput);
soft_sender(nOTs_requested);
}
channel->send("hello", 6);
char buf[6];
channel->recv(buf, 6);
assert(buf == string("hello"));
#endif
}
#ifndef USE_KOS
void OTExtensionWithMatrix::soft_sender(size_t n)
{
if (not (ot_role & SENDER))
return;
osuCrypto::PRNG prng(osuCrypto::sysRandomSeed());
osuCrypto::SoftSpokenOT::TwoOneMaliciousSender sender(2);
vector<osuCrypto::block> outputs;
for (auto& x : G_receiver)
{
outputs.push_back(x.get_doubleword());
}
sender.setBaseOts(outputs,
{baseReceiverInput.get_ptr(), sender.baseOtCount()}, prng,
*channel);
// Choose which messages should be sent.
auto sendMessages = osuCrypto::allocAlignedBlockArray<std::array<osuCrypto::block, 2>>(n);
// Send the messages.
sender.send(gsl::span(sendMessages.get(), n), prng, *channel);
for (size_t i = 0; i < n; i++)
for (int j = 0; j < 2; j++)
senderOutputMatrices[j].squares.at(i / 128).rows[i % 128] =
sendMessages[i][j];
}
void OTExtensionWithMatrix::soft_receiver(size_t n,
const BitVector& newReceiverInput)
{
if (not (ot_role & RECEIVER))
return;
osuCrypto::PRNG prng(osuCrypto::sysRandomSeed());
osuCrypto::SoftSpokenOT::TwoOneMaliciousReceiver recver(2);
vector<array<osuCrypto::block, 2>> inputs;
for (auto& x : G_sender)
{
inputs.push_back({});
for (int i = 0; i < 2; i++)
inputs.back()[i] = x[i].get_doubleword();
}
recver.setBaseOts(inputs, prng, *channel);
// Choose which messages should be received.
osuCrypto::BitVector choices(n);
assert (n == newReceiverInput.size());
for (size_t i = 0; i < n; i++)
choices[i] = newReceiverInput.get_bit(i);
// Receive the messages
std::vector<osuCrypto::block, osuCrypto::AlignedBlockAllocator> messages(n);
recver.receive(choices, messages, prng, *channel);
for (size_t i = 0; i < n; i++)
{
receiverOutputMatrix.squares.at(i / 128).rows[i % 128] = messages[i];
}
}
#endif
void OTExtensionWithMatrix::extend_correlated(const BitVector& newReceiverInput)
{
extend_correlated(newReceiverInput.size(), newReceiverInput);

View File

@@ -10,6 +10,11 @@
#include "BitMatrix.h"
#include "Math/gf2n.h"
#ifndef USE_KOS
#include <libOTe/TwoChooseOne/SoftSpokenOT/TwoOneMalicious.h>
#include <cryptoTools/Network/IOService.h>
#endif
template <class U>
class OTCorrelator : public OTExtension
{
@@ -47,8 +52,17 @@ public:
class OTExtensionWithMatrix : public OTCorrelator<BitMatrix>
{
static bool warned;
int nsubloops;
#ifndef USE_KOS
static osuCrypto::IOService ios;
osuCrypto::Channel* channel;
#endif
bool agreed;
public:
PRNG G;
@@ -63,12 +77,20 @@ public:
nsubloops(nsubloops)
{
G.ReSeed();
agreed = false;
#ifndef USE_KOS
channel = 0;
#endif
}
OTExtensionWithMatrix(BaseOT& baseOT, TwoPartyPlayer* player, bool passive);
~OTExtensionWithMatrix();
void protocol_agreement();
void transfer(int nOTs, const BitVector& receiverInput, int nloops);
void extend(int nOTs, BitVector& newReceiverInput);
void extend(int nOTs, const BitVector& newReceiverInput);
void extend_correlated(const BitVector& newReceiverInput);
void extend_correlated(int nOTs, const BitVector& newReceiverInput);
void transpose(int start = 0, int slice = -1);
@@ -77,6 +99,10 @@ public:
void hash_outputs(int nOTs, vector<V>& senderOutput, V& receiverOutput,
bool correlated = true);
// SoftSpokenOT
void soft_sender(size_t nOTs);
void soft_receiver(size_t nOTs, const BitVector& newReceiverInput);
void print(BitVector& newReceiverInput, int i = 0);
template <class T>
void print_receiver(BitVector& newReceiverInput, BitMatrix& matrix, int i = 0, int offset = 0);

View File

@@ -54,7 +54,7 @@ class OTMultiplier : public OTMultiplierMac<typename T::sacri_type, typename T::
{
protected:
BitVector keyBits;
vector< vector<BitVector> > senderOutput;
vector< array<BitVector, 2> > senderOutput;
vector<BitVector> receiverOutput;
void multiplyForTriples();
@@ -63,7 +63,7 @@ protected:
virtual void after_correlation() = 0;
virtual void init_authenticator(const BitVector& baseReceiverInput,
const vector< vector<BitVector> >& baseSenderInput,
const vector< array<BitVector, 2> >& baseSenderInput,
const vector<BitVector>& baseReceiverOutput) = 0;
public:
@@ -84,7 +84,7 @@ class MascotMultiplier : public OTMultiplier<T>
OTCorrelator<Matrix<typename T::Square> > auth_ot_ext;
void after_correlation();
void init_authenticator(const BitVector& baseReceiverInput,
const vector< vector<BitVector> >& baseSenderInput,
const vector< array<BitVector, 2> >& baseSenderInput,
const vector<BitVector>& baseReceiverOutput);
void multiplyForBits();
@@ -108,7 +108,7 @@ class TinyMultiplier : public OTMultiplier<T>
void after_correlation();
void init_authenticator(const BitVector& baseReceiverInput,
const vector< vector<BitVector> >& baseSenderInput,
const vector< array<BitVector, 2> >& baseSenderInput,
const vector<BitVector>& baseReceiverOutput);
public:
@@ -126,7 +126,7 @@ class TinierMultiplier : public OTMultiplier<T>
void after_correlation();
void init_authenticator(const BitVector& baseReceiverInput,
const vector< vector<BitVector> >& baseSenderInput,
const vector< array<BitVector, 2> >& baseSenderInput,
const vector<BitVector>& baseReceiverOutput);
public:
@@ -146,7 +146,7 @@ class Spdz2kMultiplier: public OTMultiplier<Spdz2kShare<K, S>>
void after_correlation();
void init_authenticator(const BitVector& baseReceiverInput,
const vector< vector<BitVector> >& baseSenderInput,
const vector< array<BitVector, 2> >& baseSenderInput,
const vector<BitVector>& baseReceiverOutput);
void multiplyForInputs(MultJob job);
@@ -173,10 +173,12 @@ class SemiMultiplier : public OTMultiplier<T>
throw not_implemented();
}
void multiplyForBits();
void after_correlation();
void init_authenticator(const BitVector& baseReceiverInput,
const vector< vector<BitVector> >& baseSenderInput,
const vector< array<BitVector, 2> >& baseSenderInput,
const vector<BitVector>& baseReceiverOutput)
{
(void) baseReceiverInput, (void) baseReceiverOutput, (void) baseSenderInput;

View File

@@ -96,7 +96,6 @@ void OTMultiplier<T>::multiply()
senderOutput.resize(keyBits.size());
for (size_t j = 0; j < keyBits.size(); j++)
{
senderOutput[j].resize(2);
for (int i = 0; i < 2; i++)
{
senderOutput[j][i].resize(128);
@@ -136,6 +135,59 @@ void OTMultiplier<T>::multiply()
}
}
template<class T>
void SemiMultiplier<T>::multiplyForBits()
{
auto& rot_ext = this->rot_ext;
auto& otCorrelator = this->otCorrelator;
OT_ROLE role;
if (this->generator.players[0]->my_num())
role = SENDER;
else
role = RECEIVER;
rot_ext.set_role(INV_ROLE(role));
otCorrelator.set_role(role);
BitVector aBits = this->generator.valueBits[0];
rot_ext.extend_correlated(aBits);
typedef typename T::Rectangle X;
vector<Matrix<X> >& baseSenderOutputs = otCorrelator.matrices;
Matrix<X>& baseReceiverOutput = otCorrelator.senderOutputMatrices[0];
rot_ext.hash_outputs(aBits.size(), baseSenderOutputs, baseReceiverOutput);
otCorrelator.setup_for_correlation(aBits, baseSenderOutputs,
baseReceiverOutput);
otCorrelator.correlate(0, otCorrelator.receiverOutputMatrix.squares.size(),
this->generator.valueBits[0], false, -1);
c_output.clear();
for (unsigned j = 0; j < aBits.size(); j++)
{
int outer = j / X::n_rows();
int inner = j % X::n_rows();
if (role == RECEIVER)
c_output.push_back(
typename T::open_type()
- otCorrelator.receiverOutputMatrix.squares.at(
outer).rows[inner]);
else
c_output.push_back(
otCorrelator.senderOutputMatrices[0].squares.at(outer).rows[inner]);
}
rot_ext.set_role(BOTH);
otCorrelator.set_role(BOTH);
this->outbox.push({});
}
template<class W>
void OTMultiplier<W>::multiplyForTriples()
{
@@ -152,6 +204,8 @@ void OTMultiplier<W>::multiplyForTriples()
auto& outbox = this->outbox;
outbox.push(job);
bool corr_hash = generator.machine.use_extension;
for (int i = 0; i < generator.nloops; i++)
{
this->inbox.pop(job);
@@ -159,7 +213,12 @@ void OTMultiplier<W>::multiplyForTriples()
//timers["Extension"].start();
if (generator.machine.use_extension)
{
#ifdef USE_KOS
rot_ext.extend_correlated(aBits);
#else
rot_ext.extend(aBits.size(), aBits);
corr_hash = false;
#endif
}
else
{
@@ -175,8 +234,9 @@ void OTMultiplier<W>::multiplyForTriples()
bot.sender_inputs[i][j].get_int128(0).a;
}
}
rot_ext.hash_outputs(aBits.size(), baseSenderOutputs,
baseReceiverOutput, generator.machine.use_extension);
baseReceiverOutput, corr_hash);
//timers["Extension"].stop();
//timers["Correlation"].start();
@@ -194,14 +254,14 @@ void OTMultiplier<W>::multiplyForTriples()
template <class T>
void MascotMultiplier<T>::init_authenticator(const BitVector& keyBits,
const vector< vector<BitVector> >& senderOutput,
const vector< array<BitVector, 2> >& senderOutput,
const vector<BitVector>& receiverOutput) {
this->auth_ot_ext.init(keyBits, senderOutput, receiverOutput);
}
template<class T>
void TinyMultiplier<T>::init_authenticator(const BitVector& keyBits,
const vector<vector<BitVector> >& senderOutput,
const vector<array<BitVector, 2> >& senderOutput,
const vector<BitVector>& receiverOutput)
{
mac_vole.init(keyBits, senderOutput, receiverOutput);
@@ -209,7 +269,7 @@ void TinyMultiplier<T>::init_authenticator(const BitVector& keyBits,
template <class T>
void TinierMultiplier<T>::init_authenticator(const BitVector& keyBits,
const vector< vector<BitVector> >& senderOutput,
const vector< array<BitVector, 2> >& senderOutput,
const vector<BitVector>& receiverOutput)
{
auto tmpBits = keyBits;
@@ -219,7 +279,6 @@ void TinierMultiplier<T>::init_authenticator(const BitVector& keyBits,
SeededPRNG G;
for (auto& x : tmpSenderOutput)
{
x.resize(2);
for (auto& y : x)
if (y.size() == 0)
{
@@ -236,7 +295,7 @@ void TinierMultiplier<T>::init_authenticator(const BitVector& keyBits,
template <int K, int S>
void Spdz2kMultiplier<K, S>::init_authenticator(const BitVector& keyBits,
const vector< vector<BitVector> >& senderOutput,
const vector< array<BitVector, 2> >& senderOutput,
const vector<BitVector>& receiverOutput) {
this->mac_vole->init(keyBits, senderOutput, receiverOutput);
input_mac_vole->init(keyBits, senderOutput, receiverOutput);
@@ -426,7 +485,7 @@ void MascotMultiplier<T>::multiplyForBits(true_type)
BitVector extKeyBits = this->keyBits;
extKeyBits.resize_zero(128);
auto extSenderOutput = this->senderOutput;
extSenderOutput.resize(128, {2, BitVector(128)});
extSenderOutput.resize(128, {{2, BitVector(128)}});
SeededPRNG G;
for (auto& x : extSenderOutput)
for (auto& y : x)

View File

@@ -22,7 +22,7 @@ class OTTripleSetup
public:
map<string,Timer> timers;
vector<OffsetPlayer*> players;
vector< vector< vector<BitVector> > > baseSenderInputs;
vector< vector< array<BitVector, 2> > > baseSenderInputs;
vector< vector<BitVector> > baseReceiverOutputs;
int get_nparties() const { return nparties; }
@@ -82,5 +82,28 @@ public:
OTTripleSetup get_fresh();
};
class OnDemandOTTripleSetup
{
OTTripleSetup* setup;
public:
OnDemandOTTripleSetup() :
setup(0)
{
}
~OnDemandOTTripleSetup()
{
if (setup)
delete setup;
}
OTTripleSetup get_fresh(Player& P)
{
if (not setup)
setup = new OTTripleSetup(P, true);
return setup->get_fresh();
}
};
#endif

View File

@@ -47,6 +47,8 @@ public:
Rectangle<U, V>& sub(const void* other) { return sub_(other); }
Rectangle<U, V>& sub_(const void* other);
void bit_sub(const BitVector& bits, int start);
void mul(const BitVector& a, const V& b);
void randomize(PRNG& G);

View File

@@ -58,6 +58,13 @@ Rectangle<U, V>& Rectangle<U, V>::sub_(const void* other)
return *this;
}
template<class U, class V>
void Rectangle<U, V>::bit_sub(const BitVector& bits, int start)
{
for (int i = 0; i < N_ROWS; i++)
rows[i] = rows[i] - bits.get_bit(start + i);
}
template<class U, class V>
void Rectangle<U, V>::mul(const BitVector& a, const V& b)
{

View File

@@ -11,6 +11,7 @@
#include "Math/Z2k.h"
#include "OT/OTTripleSetup.h"
#include "OT/MascotParams.h"
#include "Tools/NetworkOptions.h"
class GeneratorThread;
@@ -25,11 +26,15 @@ class TripleMachine : public OfflineMachineBase, public MascotParams
bigint prime;
public:
Player* player;
int nloops;
bool bonding;
int z2k, z2s;
NetworkOptionsWithNumber network_opts;
public:
TripleMachine(int argc, const char** argv);
template<class T>

View File

@@ -6,6 +6,7 @@
#include "BaseMachine.h"
#include "OnlineOptions.h"
#include "Math/Setup.h"
#include "Tools/Bundle.h"
#include <iostream>
#include <sodium.h>
@@ -13,6 +14,7 @@ using namespace std;
BaseMachine* BaseMachine::singleton = 0;
thread_local int BaseMachine::thread_num;
thread_local OnDemandOTTripleSetup BaseMachine::ot_setup;
void print_usage(ostream& o, const char* name, size_t capacity)
{
@@ -126,12 +128,12 @@ void BaseMachine::stop(int n)
void BaseMachine::print_timers()
{
cerr << "The following timing is ";
cerr << "The following benchmarks are ";
if (OnlineOptions::singleton.live_prep)
cerr << "in";
else
cerr << "ex";
cerr << "clusive preprocessing." << endl;
cerr << "cluding preprocessing (offline phase)." << endl;
cerr << "Time = " << timer[0].elapsed() << " seconds " << endl;
timer.erase(0);
for (auto it = timer.begin(); it != timer.end(); it++)
@@ -196,3 +198,14 @@ void BaseMachine::set_thread_comm(const NamedCommStats& stats)
assert(queue);
queue->set_comm_stats(stats);
}
void BaseMachine::print_global_comm(Player& P, const NamedCommStats& stats)
{
Bundle<octetStream> bundle(P);
bundle.mine.store(stats.sent);
P.Broadcast_Receive_no_stats(bundle);
size_t global = 0;
for (auto& os : bundle)
global += os.get_int(8);
cerr << "Global data sent = " << global / 1e6 << " MB (all parties)" << endl;
}

View File

@@ -23,6 +23,8 @@ class BaseMachine
protected:
static BaseMachine* singleton;
static thread_local OnDemandOTTripleSetup ot_setup;
std::map<int,TimerWithComm> timer;
string compiler;
@@ -39,8 +41,6 @@ public:
string progname;
int nthreads;
vector<OTTripleSetup> ot_setups;
ThreadQueues queues;
vector<string> bc_filenames;
@@ -71,14 +71,13 @@ public:
NamedCommStats total_comm();
void set_thread_comm(const NamedCommStats& stats);
void print_global_comm(Player& P, const NamedCommStats& stats);
};
inline OTTripleSetup BaseMachine::fresh_ot_setup(Player& P)
{
if (singleton and size_t(thread_num) < s().ot_setups.size())
return s().ot_setups.at(thread_num).get_fresh();
else
return OTTripleSetup(P, true);
return ot_setup.get_fresh(P);
}
#endif /* PROCESSOR_BASEMACHINE_H_ */

View File

@@ -33,9 +33,9 @@ void BaseInstruction::parse(istream& s, int inst_pos)
r[0]=0; r[1]=0; r[2]=0; r[3]=0;
int pos=s.tellg();
opcode=get_int(s);
size=unsigned(opcode)>>10;
opcode&=0x3FF;
uint64_t code = get_long(s);
size = code >> 10;
opcode = 0x3FF & code;
if (size==0)
size=1;
@@ -288,8 +288,6 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
get_vector(2, start, s);
break;
// open instructions + read/write instructions with variable length args
case OPEN:
case GOPEN:
case MULS:
case GMULS:
case MULRS:
@@ -475,6 +473,8 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
n = get_int(s);
get_vector(4, start, s);
break;
case OPEN:
case GOPEN:
case TRANS:
num_var_args = get_int(s) - 1;
n = get_int(s);
@@ -1013,6 +1013,7 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
Proc.Procp.send_personal(start);
return;
case PRIVATEOUTPUT:
Proc.Procp.check();
Proc.Procp.private_output(start);
return;
// Note: Fp version has different semantics for NOTC than GNOTC
@@ -1032,10 +1033,10 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
sgf2n::shrsi(Proc2, *this);
return;
case OPEN:
Proc.Procp.POpen(start, Proc.P, size);
Proc.Procp.POpen(*this);
return;
case GOPEN:
Proc.Proc2.POpen(start, Proc.P, size);
Proc.Proc2.POpen(*this);
return;
case MULS:
Proc.Procp.muls(start, size);

View File

@@ -85,6 +85,10 @@ Machine<sint, sgf2n>::Machine(Names& playerNames, bool use_encryption,
sint::LivePrep::basic_setup(*P);
}
sint::MAC_Check::setup(*P);
sint::bit_type::MAC_Check::setup(*P);
sgf2n::MAC_Check::setup(*P);
alphapi = read_generate_write_mac_key<sint>(*P);
alpha2i = read_generate_write_mac_key<sgf2n>(*P);
alphabi = read_generate_write_mac_key<typename
@@ -153,13 +157,6 @@ void Machine<sint, sgf2n>::prepare(const string& progname_str)
progs[0].print_offline_cost();
#endif
if (live_prep
and (sint::needs_ot or sgf2n::needs_ot or sint::bit_type::needs_ot))
{
for (int i = old_n_threads; i < nthreads; i++)
ot_setups.push_back({ *P, true });
}
/* Set up the threads */
tinfo.resize(nthreads);
threads.resize(nthreads);
@@ -192,6 +189,10 @@ Machine<sint, sgf2n>::~Machine()
sint::LivePrep::teardown();
sgf2n::LivePrep::teardown();
sint::MAC_Check::teardown();
sint::bit_type::MAC_Check::teardown();
sgf2n::MAC_Check::teardown();
delete P;
for (auto& queue : queues)
delete queue;
@@ -475,13 +476,7 @@ void Machine<sint, sgf2n>::run(const string& progname)
cerr << ")" << endl;
auto& P = *this->P;
Bundle<octetStream> bundle(P);
bundle.mine.store(comm_stats.sent);
P.Broadcast_Receive_no_stats(bundle);
size_t global = 0;
for (auto& os : bundle)
global += os.get_int(8);
cerr << "Global data sent = " << global / 1e6 << " MB (all parties)" << endl;
this->print_global_comm(P, comm_stats);
#ifdef VERBOSE_OPTIONS
if (opening_sum < N.num_players() && !direct)

View File

@@ -15,7 +15,7 @@ public:
NoFilePrep(int, int, const string&, DataPositions& usage, int = -1) :
Preprocessing<T>(usage)
{
throw runtime_error("don't call this");
throw runtime_error("preprocessing from file disabled");
}
};

View File

@@ -21,9 +21,15 @@ OfflineMachine<W>::OfflineMachine(int argc, const char** argv,
machine.load_schedule(online_opts.progname, false);
Program program(playerNames.num_players());
program.parse(machine.bc_filenames[0]);
if (program.usage_unknown())
{
cerr << "Preprocessing might be insufficient "
<< "due to unknown requirements" << endl;
}
usage = program.get_offline_data_used();
n_threads = machine.nthreads;
machine.ot_setups.push_back({P});
}
template<class W>
@@ -47,12 +53,20 @@ int OfflineMachine<W>::run()
// setup before generation to fix prime
T::LivePrep::basic_setup(P);
T::MAC_Check::setup(P);
T::bit_type::MAC_Check::setup(P);
U::MAC_Check::setup(P);
generate<T>();
generate<typename T::bit_type::part_type>();
generate<U>();
thread.MC->Check(P);
T::MAC_Check::teardown();
T::bit_type::MAC_Check::teardown();
U::MAC_Check::teardown();
return 0;
}

View File

@@ -268,6 +268,9 @@ void thread_info<sint, sgf2n>::Sub_Main_Func()
// Execute the program
progs[program].execute(Proc);
// make sure values used in other threads are safe
Proc.check();
// prevent mangled output
cout.flush();

View File

@@ -38,6 +38,11 @@ public:
int run();
Player* new_player(const string& id_base);
int get_lg2()
{
return lg2;
}
};
class DishonestMajorityMachine : public OnlineMachine

View File

@@ -14,16 +14,6 @@
using namespace std;
template<class T, class U>
int spdz_main(int argc, const char** argv, ez::ezOptionParser& opt, bool live_prep_default = true)
{
OnlineOptions& online_opts = OnlineOptions::singleton;
online_opts = {opt, argc, argv, T(), live_prep_default};
DishonestMajorityMachine machine(argc, argv, opt, online_opts, typename U::clear());
return machine.run<T, U>();
}
template<class V>
OnlineMachine::OnlineMachine(int argc, const char** argv, ez::ezOptionParser& opt,
OnlineOptions& online_opts, int nplayers, V) :

View File

@@ -346,17 +346,14 @@ void OnlineOptions::finalize(ez::ezOptionParser& opt, int argc,
prime = schedule_prime;
}
// ignore program if length explicitly set from command line
if (opt.get("-lgp") and not opt.isSet("-lgp"))
{
int prog_lgp = BaseMachine::prime_length_from_schedule(progname);
prog_lgp = DIV_CEIL(prog_lgp, 64) * 64;
if (prog_lgp != 0)
// only increase to be consistent with program not demanding any length
if (prog_lgp > lgp)
lgp = prog_lgp;
#ifndef FEWER_PRIMES
if (prime_limbs() > 4)
#endif
lgp = max(lgp, gfp0::MAX_N_BITS);
}
set_trunc_error(opt);

View File

@@ -61,8 +61,10 @@ public:
ArithmeticProcessor* Proc = 0);
~SubProcessor();
void check();
// Access to PO (via calls to POpen start/stop)
void POpen(const vector<int>& reg,const Player& P,int size);
void POpen(const Instruction& inst);
void muls(const vector<int>& reg, int size);
void mulrs(const vector<int>& reg);

Some files were not shown because too many files have changed in this diff Show More