Python 3, semi-honest computation using semi-homomorphic encryption.

This commit is contained in:
Marcel Keller
2019-11-21 17:23:51 +11:00
parent 3bf45ebbaf
commit 470b075803
138 changed files with 1932 additions and 1022 deletions

View File

@@ -133,7 +133,7 @@ void RealGarbleWire<T>::input(party_id_t from, char input)
protocol.init_mul(party.shared_proc);
protocol.prepare_mul(mask, T(1, party.P->my_num(), party.mac_key) - mask);
protocol.exchange();
if (party.MC->POpen(protocol.finalize_mul(), *party.P) != 0)
if (party.MC->open(protocol.finalize_mul(), *party.P) != 0)
throw runtime_error("input mask not a bit");
}
#ifdef DEBUG_MASK
@@ -168,7 +168,7 @@ void RealGarbleWire<T>::output()
auto& party = RealProgramParty<T>::s();
assert(party.MC != 0);
assert(party.P != 0);
auto m = party.MC->POpen(mask, *party.P);
auto m = party.MC->open(mask, *party.P);
party.output_masks.push_back(m.get_bit(0));
party.taint();
#ifdef DEBUG_MASK

View File

@@ -90,7 +90,7 @@ RealProgramParty<T>::RealProgramParty(int argc, const char** argv) :
{
mac_key.randomize(prng);
if (T::needs_ot)
BaseMachine::s().ot_setups.push_back({{{*P, true}}});
BaseMachine::s().ot_setups.push_back({*P, true});
prep = Preprocessing<T>::get_live_prep(0, usage);
}
else

View File

@@ -1,5 +1,11 @@
The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here.
## 0.1.3 (Nov 21, 2019)
- Python 3
- Semi-honest computation based on semi-homomorphic encryption
- Access to player information in high-level language
## 0.1.2 (Oct 11, 2019)
- Machine learning capabilities used for [MobileNets inference](https://eprint.iacr.org/2019/131) and the iDASH submission

View File

@@ -5,6 +5,6 @@ class Program(object):
types.program = self
instructions.program = self
self.curr_tape = None
execfile(progname)
exec(compile(open(progname).read(), progname, 'exec'))
def malloc(self, *args):
pass

View File

@@ -5,6 +5,7 @@ from Compiler.exceptions import *
from Compiler import util, oram, floatingpoint, library
import Compiler.GC.instructions as inst
import operator
from functools import reduce
class bits(Tape.Register, _structure):
n = 40
@@ -82,7 +83,7 @@ class bits(Tape.Register, _structure):
cls.load_inst[util.is_constant(address)](res, address)
return res
def store_in_mem(self, address):
self.store_inst[isinstance(address, (int, long))](self, address)
self.store_inst[isinstance(address, int)](self, address)
def __init__(self, value=None, n=None, size=None):
if size != 1 and size is not None:
raise Exception('invalid size for bit type: %s' % size)
@@ -92,11 +93,11 @@ class bits(Tape.Register, _structure):
self.load_other(value)
def set_length(self, n):
if n > self.max_length:
print self.max_length
print(self.max_length)
raise Exception('too long: %d' % n)
self.n = n
def load_other(self, other):
if isinstance(other, (int, long)):
if isinstance(other, int):
self.set_length(self.n or util.int_len(other))
self.load_int(other)
elif isinstance(other, regint):
@@ -115,6 +116,7 @@ class bits(Tape.Register, _structure):
def __repr__(self):
return '%s(%d/%d)' % \
(super(bits, self).__repr__(), self.n, type(self).n)
__str__ = __repr__
class cbits(bits):
max_length = 64
@@ -219,13 +221,13 @@ class sbits(bits):
@classmethod
def load_dynamic_mem(cls, address):
res = cls()
if isinstance(address, (long, int)):
if isinstance(address, int):
inst.ldmsd(res, address, cls.n)
else:
inst.ldmsdi(res, address, cls.n)
return res
def store_in_dynamic_mem(self, address):
if isinstance(address, (long, int)):
if isinstance(address, int):
inst.stmsd(self, address)
else:
inst.stmsdi(self, cbits.conv(address))
@@ -322,7 +324,7 @@ class sbits(bits):
mul_bits = [self if b else zero for b in bits]
return self.bit_compose(mul_bits)
else:
print self.n, other
print(self.n, other)
return NotImplemented
def __lshift__(self, i):
return self.bit_compose([sbit(0)] * i + self.bit_decompose()[:self.max_length-i])
@@ -478,7 +480,7 @@ class bitsBlock(oram.Block):
self.start_demux = oram.demux_list(self.start_bits)
self.entries = [sbits.bit_compose(self.value_bits[i*length:][:length]) \
for i in range(entries_per_block)]
self.mul_entries = map(operator.mul, self.start_demux, self.entries)
self.mul_entries = list(map(operator.mul, self.start_demux, self.entries))
self.bits = sum(self.mul_entries).bit_decompose()
self.mul_value = sbits.compose(self.mul_entries, sum(self.lengths))
self.anti_value = self.mul_value + self.value
@@ -662,6 +664,12 @@ class sbitfix(_fix):
return super(sbitfix, self).__mul__(other)
__rxor__ = __xor__
__rmul__ = __mul__
@staticmethod
def multipliable(other, k, f):
class cls(_fix):
int_type = sbitint.get_type(k)
cls.set_precision(f, k)
return cls._new(cls.int_type(other), k, f)
sbitfix.set_precision(20, 41)

View File

@@ -1,15 +1,15 @@
import compilerLib, program, instructions, types, library, floatingpoint
import GC.types
from . import compilerLib, program, instructions, types, library, floatingpoint
from .GC import types as GC_types
import inspect
from config import *
from compilerLib import run
from .config import *
from .compilerLib import run
# add all instructions to the program VARS dictionary
compilerLib.VARS = {}
instr_classes = [t[1] for t in inspect.getmembers(instructions, inspect.isclass)]
for mod in (types, GC.types):
for mod in (types, GC_types):
instr_classes += [t[1] for t in inspect.getmembers(mod, inspect.isclass)\
if t[1].__module__ == mod.__name__]

View File

@@ -10,16 +10,17 @@ import Compiler.program
import heapq, itertools
import operator
import sys
from functools import reduce
class StraightlineAllocator:
"""Allocate variables in a straightline program using n registers.
It is based on the precondition that every register is only defined once."""
def __init__(self, n):
self.alloc = {}
self.alloc = dict_by_id()
self.usage = Compiler.program.RegType.create_dict(lambda: 0)
self.defined = {}
self.dealloc = set()
self.defined = dict_by_id()
self.dealloc = set_by_id()
self.n = n
def alloc_reg(self, reg, free):
@@ -77,8 +78,8 @@ class StraightlineAllocator:
unused_regs.append(j)
if unused_regs and len(unused_regs) == len(list(i.get_def())):
# only report if all assigned registers are unused
print "Register(s) %s never used, assigned by '%s' in %s" % \
(unused_regs,i,format_trace(i.caller))
print("Register(s) %s never used, assigned by '%s' in %s" % \
(unused_regs,i,format_trace(i.caller)))
for j in i.get_used():
self.alloc_reg(j, alloc_pool)
@@ -86,7 +87,7 @@ class StraightlineAllocator:
self.dealloc_reg(j, i, alloc_pool)
if k % 1000000 == 0 and k > 0:
print "Allocated registers for %d instructions at" % k, time.asctime()
print("Allocated registers for %d instructions at" % k, time.asctime())
# print "Successfully allocated registers"
# print "modp usage: %d clear, %d secret" % \
@@ -97,8 +98,8 @@ class StraightlineAllocator:
def determine_scope(block, options):
last_def = defaultdict(lambda: -1)
used_from_scope = set()
last_def = defaultdict_by_id(lambda: -1)
used_from_scope = set_by_id()
def find_in_scope(reg, scope):
while True:
@@ -114,18 +115,18 @@ def determine_scope(block, options):
used_from_scope.add(reg)
reg.can_eliminate = False
else:
print 'Warning: read before write at register', reg
print '\tline %d: %s' % (n, instr)
print '\tinstruction trace: %s' % format_trace(instr.caller, '\t\t')
print '\tregister trace: %s' % format_trace(reg.caller, '\t\t')
print('Warning: read before write at register', reg)
print('\tline %d: %s' % (n, instr))
print('\tinstruction trace: %s' % format_trace(instr.caller, '\t\t'))
print('\tregister trace: %s' % format_trace(reg.caller, '\t\t'))
if options.stop:
sys.exit(1)
def write(reg, n):
if last_def[reg] != -1:
print 'Warning: double write at register', reg
print '\tline %d: %s' % (n, instr)
print '\ttrace: %s' % format_trace(instr.caller, '\t\t')
print('Warning: double write at register', reg)
print('\tline %d: %s' % (n, instr))
print('\ttrace: %s' % format_trace(instr.caller, '\t\t'))
if options.stop:
sys.exit(1)
last_def[reg] = n
@@ -146,7 +147,7 @@ def determine_scope(block, options):
write(reg, n)
block.used_from_scope = used_from_scope
block.defined_registers = set(last_def.iterkeys())
block.defined_registers = set_by_id(last_def.keys())
class Merger:
def __init__(self, block, options, merge_classes):
@@ -178,7 +179,7 @@ class Merger:
if inst.is_vec():
for arg in inst.args:
arg.create_vector_elements()
res = sum(zip(*inst.args), ())
res = sum(list(zip(*inst.args)), ())
return list(res)
else:
return inst.args
@@ -241,7 +242,7 @@ class Merger:
remaining_input_nodes = []
def do_merge(nodes):
if len(nodes) > 1000:
print 'Merging %d inputs...' % len(nodes)
print('Merging %d inputs...' % len(nodes))
self.do_merge(iter(nodes))
for n in self.input_nodes:
inst = self.instructions[n]
@@ -252,7 +253,7 @@ class Merger:
if len(merge) >= self.max_parallel_open:
do_merge(merge)
merge[:] = []
for merge in reversed(sorted(merges.itervalues())):
for merge in reversed(sorted(merges.values())):
if merge:
do_merge(merge)
self.input_nodes = remaining_input_nodes
@@ -266,7 +267,7 @@ class Merger:
instructions = self.instructions
flex_nodes = defaultdict(dict)
starters = []
for n in xrange(len(G)):
for n in range(len(G)):
if n not in merge_nodes_set and \
depth_of[n] != rev_depth_of[n] and G[n] and G.get_attr(n,'start') == -1 and not isinstance(instructions[n], AsymmetricCommunicationInstruction):
#print n, depth_of[n], rev_depth_of[n]
@@ -275,19 +276,19 @@ class Merger:
not isinstance(self.instructions[n], RawInputInstruction):
starters.append(n)
if n % 10000000 == 0 and n > 0:
print "Processed %d nodes at" % n, time.asctime()
print("Processed %d nodes at" % n, time.asctime())
inputs = defaultdict(list)
for node in self.input_nodes:
player = self.instructions[node].args[0]
inputs[player].append(node)
first_inputs = [l[0] for l in inputs.itervalues()]
first_inputs = [l[0] for l in inputs.values()]
other_inputs = []
i = 0
while True:
i += 1
found = False
for l in inputs.itervalues():
for l in inputs.values():
if i < len(l):
other_inputs.append(l[i])
found = True
@@ -299,20 +300,20 @@ class Merger:
# magical preorder for topological search
max_depth = max(merges)
if max_depth > 10000:
print "Computing pre-ordering ..."
for i in xrange(max_depth, 0, -1):
print("Computing pre-ordering ...")
for i in range(max_depth, 0, -1):
preorder.append(G.get_attr(merges[i], 'stop'))
for j in flex_nodes[i-1].itervalues():
for j in flex_nodes[i-1].values():
preorder.extend(j)
preorder.extend(flex_nodes[0].get(i, []))
preorder.append(merges[i])
if i % 100000 == 0 and i > 0:
print "Done level %d at" % i, time.asctime()
print("Done level %d at" % i, time.asctime())
preorder.extend(other_inputs)
preorder.extend(starters)
preorder.extend(first_inputs)
if max_depth > 10000:
print "Done at", time.asctime()
print("Done at", time.asctime())
return preorder
def longest_paths_merge(self):
@@ -343,8 +344,8 @@ class Merger:
t = type(self.instructions[merge[0]])
self.counter[t] += len(merge)
if len(merge) > 1000:
print 'Merging %d %s in round %d/%d' % \
(len(merge), t.__name__, i, len(merges))
print('Merging %d %s in round %d/%d' % \
(len(merge), t.__name__, i, len(merges)))
self.do_merge(merge)
self.merge_inputs()
@@ -352,11 +353,11 @@ class Merger:
preorder = None
if len(instructions) > 100000:
print "Topological sort ..."
print("Topological sort ...")
order = Compiler.graph.topological_sort(G, preorder)
instructions[:] = [instructions[i] for i in order if instructions[i] is not None]
if len(instructions) > 100000:
print "Done at", time.asctime()
print("Done at", time.asctime())
return len(merges)
@@ -377,7 +378,7 @@ class Merger:
self.G = G
reg_nodes = {}
last_def = defaultdict(lambda: -1)
last_def = defaultdict_by_id(lambda: -1)
last_mem_write = []
last_mem_read = []
warned_about_mem = []
@@ -411,8 +412,8 @@ class Merger:
def handle_mem_access(addr, reg_type, last_access_this_kind,
last_access_other_kind):
this = last_access_this_kind[addr,reg_type]
other = last_access_other_kind[addr,reg_type]
this = last_access_this_kind[str(addr),reg_type]
other = last_access_other_kind[str(addr),reg_type]
if this and other:
if this[-1] < other[0]:
del this[:]
@@ -429,15 +430,15 @@ class Merger:
handle_mem_access(addr_i, reg_type, last_access_this_kind,
last_access_other_kind)
if not warned_about_mem and (instr.get_size() > 100):
print 'WARNING: Order of memory instructions ' \
'not preserved due to long vector, errors possible'
print('WARNING: Order of memory instructions ' \
'not preserved due to long vector, errors possible')
warned_about_mem.append(True)
else:
handle_mem_access(addr, reg_type, last_access_this_kind,
last_access_other_kind)
if not warned_about_mem and not isinstance(instr, DirectMemoryInstruction):
print 'WARNING: Order of memory instructions ' \
'not preserved, errors possible'
print('WARNING: Order of memory instructions ' \
'not preserved, errors possible')
# hack
warned_about_mem.append(True)
@@ -553,11 +554,11 @@ class Merger:
self.sources.append(n)
if n % 100000 == 0 and n > 0:
print "Processed dependency of %d/%d instructions at" % \
(n, len(block.instructions)), time.asctime()
print("Processed dependency of %d/%d instructions at" % \
(n, len(block.instructions)), time.asctime())
if len(open_nodes) > 1000:
print "Program has %d %s instructions" % (len(open_nodes), merge_classes)
print("Program has %d %s instructions" % (len(open_nodes), merge_classes))
def merge_nodes(self, i, j):
""" Merge node j into i, removing node j """
@@ -566,8 +567,8 @@ class Merger:
G.remove_edge(i, j)
if i in G[j]:
G.remove_edge(j, i)
G.add_edges_from(zip(itertools.cycle([i]), G[j], [G.weights[(j,k)] for k in G[j]]))
G.add_edges_from(zip(G.pred[j], itertools.cycle([i]), [G.weights[(k,j)] for k in G.pred[j]]))
G.add_edges_from(list(zip(itertools.cycle([i]), G[j], [G.weights[(j,k)] for k in G[j]])))
G.add_edges_from(list(zip(G.pred[j], itertools.cycle([i]), [G.weights[(k,j)] for k in G.pred[j]])))
G.get_attr(i, 'merges').append(j)
G.remove_node(j)
@@ -578,7 +579,7 @@ class Merger:
count = 0
open_count = 0
stats = defaultdict(lambda: 0)
for i,inst in zip(xrange(len(instructions) - 1, -1, -1), reversed(instructions)):
for i,inst in zip(range(len(instructions) - 1, -1, -1), reversed(instructions)):
# remove if instruction has result that isn't used
unused_result = not G.degree(i) and len(list(inst.get_def())) \
and reduce(operator.and_, (reg.can_eliminate for reg in inst.get_def())) \
@@ -608,21 +609,21 @@ class Merger:
eliminate(i)
count += 2
if count > 0:
print 'Eliminated %d dead instructions, among which %d opens: %s' \
% (count, open_count, dict(stats))
print('Eliminated %d dead instructions, among which %d opens: %s' \
% (count, open_count, dict(stats)))
def print_graph(self, filename):
f = open(filename, 'w')
print >>f, 'digraph G {'
print('digraph G {', file=f)
for i in range(self.G.n):
for j in self.G[i]:
print >>f, '"%d: %s" -> "%d: %s";' % \
(i, self.instructions[i], j, self.instructions[j])
print >>f, '}'
print('"%d: %s" -> "%d: %s";' % \
(i, self.instructions[i], j, self.instructions[j]), file=f)
print('}', file=f)
f.close()
def print_depth(self, filename):
f = open(filename, 'w')
for i in range(self.G.n):
print >>f, '%d: %s' % (self.depths[i], self.instructions[i])
print('%d: %s' % (self.depths[i], self.instructions[i]), file=f)
f.close()

View File

@@ -26,7 +26,7 @@ def find_deeper(a, b, path, start, length, compute_level=True):
any_empty = OR(a.empty, b.empty)
a_diff = [XOR(a_bits[i], path_bits[i]) for i in range(start, length)]
b_diff = [XOR(b_bits[i], path_bits[i]) for i in range(start, length)]
diff = [XOR(ab, bb) for ab,bb in zip(a_bits, b_bits)[start:length]]
diff = [XOR(ab, bb) for ab,bb in list(zip(a_bits, b_bits))[start:length]]
diff_preor = type(a.value).PreOR([any_empty] + diff)
diff_first = [x - y for x,y in zip(diff_preor, diff_preor[1:])]
winner = sum((ad * df for ad,df in zip(a_diff, diff_first)), a.empty)
@@ -38,7 +38,7 @@ def find_deeper(a, b, path, start, length, compute_level=True):
def find_deepest(paths, search_path, start, length, compute_level=True):
if len(paths) == 1:
return None, paths[0], 1
l = len(paths) / 2
l = len(paths) // 2
_, a, a_index = find_deepest(paths[:l], search_path, start, length, False)
_, b, b_index = find_deepest(paths[l:], search_path, start, length, False)
level, winner = find_deeper(a, b, search_path, start, length, compute_level)
@@ -57,7 +57,7 @@ def greater_unary(a, b):
if len(a) == 1:
return a[0], b[0]
else:
l = len(a) / 2
l = len(a) // 2
return gu_step(greater_unary(a[l:], b[l:]), greater_unary(a[:l], b[:l]))
def comp_step(high, low):
@@ -75,7 +75,7 @@ def comp_binary(a, b):
if len(a) == 1:
return a[0], b[0]
else:
l = len(a) / 2
l = len(a) // 2
return comp_step(comp_binary(a[l:], b[l:]), comp_binary(a[:l], b[:l]))
def unary_to_binary(l):
@@ -89,8 +89,8 @@ class CircuitORAM(PathORAM):
self.D = log2(size)
self.logD = log2(self.D)
self.L = self.D + 1
print 'create oram of size %d with depth %d and %d buckets' \
% (size, self.D, self.n_buckets())
print('create oram of size %d with depth %d and %d buckets' \
% (size, self.D, self.n_buckets()))
self.value_type = value_type
self.index_type = value_type.get_type(self.D)
if entry_size is not None:
@@ -245,7 +245,7 @@ class CircuitORAM(PathORAM):
for i,_ in enumerate(self.recursive_evict_rounds()):
Program.prog.curr_tape.start_new_basicblock(name='circuit-recursive-evict-round-%d-%d' % (i, self.size))
def recursive_evict_rounds(self):
for _ in itertools.izip(self.evict_rounds(), self.index.l.recursive_evict_rounds()):
for _ in zip(self.evict_rounds(), self.index.l.recursive_evict_rounds()):
yield
def bucket_indices_on_path_to(self, leaf):
# root is at 1, different to PathORAM
@@ -272,10 +272,10 @@ threshold = 2**10
def OptimalCircuitORAM(size, value_type, *args, **kwargs):
if size <= threshold:
print size, 'below threshold', threshold
print(size, 'below threshold', threshold)
return LinearORAM(size, value_type, *args, **kwargs)
else:
print size, 'above threshold', threshold
print(size, 'above threshold', threshold)
return RecursiveCircuitORAM(size, value_type, *args, **kwargs)
class RecursiveCircuitIndexStructure(PackedIndexStructure):

View File

@@ -28,8 +28,8 @@ use_inv = True
# (r[i], r[i]^-1, r[i] * r[i-1]^-1)
do_precomp = True
import instructions_base
import util
from . import instructions_base
from . import util
def set_variant(options):
""" Set flags based on the command-line option provided """
@@ -55,7 +55,7 @@ def ld2i(c, n):
""" Load immediate 2^n into clear GF(p) register c """
t1 = program.curr_block.new_reg('c')
ldi(t1, 2 ** (n % 30))
for i in range(n / 30):
for i in range(n // 30):
t2 = program.curr_block.new_reg('c')
mulci(t2, t1, 2 ** 30)
t1 = t2
@@ -75,13 +75,13 @@ def LTZ(s, a, k, kappa):
k: bit length of a
"""
from types import sint
from .types import sint
t = sint()
Trunc(t, a, k, k - 1, kappa, True)
subsfi(s, t, 0)
def LessThanZero(a, k, kappa):
import types
from . import types
res = types.sint()
LTZ(res, a, k, kappa)
return res
@@ -124,7 +124,7 @@ def TruncZeroes(a, k, m, signed):
if program.options.ring:
return TruncLeakyInRing(a, k, m, signed)
else:
import types
from . import types
tmp = types.cint()
inv2m(tmp, m)
return a * tmp
@@ -136,7 +136,7 @@ def TruncLeakyInRing(a, k, m, signed):
"""
assert k > m
assert int(program.options.ring) >= k
from types import sint, intbitint, cint, cgf2n
from .types import sint, intbitint, cint, cgf2n
n_bits = k - m
n_shift = int(program.options.ring) - n_bits
r_bits = [sint.get_random_bit() for i in range(n_bits)]
@@ -165,7 +165,7 @@ def TruncRoundNearest(a, k, m, kappa, signed=False):
# cannot work with bit length k+1
tmp = TruncRing(None, a, k, m - 1, signed)
return TruncRing(None, tmp + 1, k - m + 1, 1, signed)
from types import sint
from .types import sint
res = sint()
Trunc(res, a + (1 << (m - 1)), k + 1, m, kappa, signed)
return res
@@ -277,7 +277,7 @@ def BitLTC1(u, a, b, kappa):
"""
k = len(b)
p = [program.curr_block.new_reg('s') for i in range(k)]
import floatingpoint
from . import floatingpoint
a_bits = floatingpoint.bits(a, k)
if instructions_base.get_global_vector_size() == 1:
a_ = a_bits
@@ -357,12 +357,12 @@ def CarryOutAux(d, a, kappa):
if k > 1 and k % 2 == 1:
a.append(None)
k += 1
u = [None]*(k/2)
u = [None]*(k//2)
a = a[::-1]
if k > 1:
for i in range(k/2):
u[i] = carry(a[2*i+1], a[2*i], i != k/2-1)
CarryOutAux(d, u[:k/2][::-1], kappa)
for i in range(k//2):
u[i] = carry(a[2*i+1], a[2*i], i != k//2-1)
CarryOutAux(d, u[:k//2][::-1], kappa)
else:
movs(d, a[0][1])
@@ -376,7 +376,7 @@ def CarryOut(res, a, b, c=0, kappa=None):
c: initial carry-in bit
"""
k = len(a)
import types
from . import types
d = [program.curr_block.new_reg('s') for i in range(k)]
t = [[types.sint() for i in range(k)] for i in range(4)]
s = [program.curr_block.new_reg('s') for i in range(3)]
@@ -394,7 +394,7 @@ def CarryOut(res, a, b, c=0, kappa=None):
def CarryOutLE(a, b, c=0):
""" Little-endian version """
import types
from . import types
res = types.sint()
CarryOut(res, a[::-1], b[::-1], c)
return res
@@ -407,7 +407,7 @@ def BitLTL(res, a, b, kappa):
b: array of secret bits (same length as a)
"""
k = len(b)
import floatingpoint
from . import floatingpoint
a_bits = floatingpoint.bits(a, k)
s = [[program.curr_block.new_reg('s') for i in range(k)] for j in range(2)]
t = [program.curr_block.new_reg('s') for i in range(1)]
@@ -547,7 +547,7 @@ def KMulC(a):
"""
Return just the product of all items in a
"""
from types import sint, cint
from .types import sint, cint
p = sint()
if use_inv:
PreMulC_with_inverses(p, a)
@@ -582,7 +582,7 @@ def Mod2(a_0, a, k, kappa, signed):
adds(t[2], t[0], t[1])
adds(t[3], t[2], r_prime)
asm_open(c, t[3])
import floatingpoint
from . import floatingpoint
c_0 = floatingpoint.bits(c, 1)[0]
mulci(tc, c_0, 2)
mulm(t[4], r_0, tc)
@@ -591,4 +591,4 @@ def Mod2(a_0, a, k, kappa, signed):
# hack for circular dependency
from instructions import *
from .instructions import *

View File

@@ -1,8 +1,8 @@
from Compiler.program import Program
from Compiler.config import *
from Compiler.exceptions import *
import instructions, instructions_base, types, comparison, library
import GC.types
from . import instructions, instructions_base, types, comparison, library
from .GC import types as GC_types
import random
import time
@@ -25,36 +25,36 @@ def run(args, options, param=-1, merge_opens=True, emulate=True, \
prog.DEBUG = debug
VARS['program'] = prog
if options.binary:
VARS['sint'] = GC.types.sbitint.get_type(int(options.binary))
VARS['sfix'] = GC.types.sbitfix
VARS['sint'] = GC_types.sbitint.get_type(int(options.binary))
VARS['sfix'] = GC_types.sbitfix
comparison.set_variant(options)
print 'Compiling file', prog.infile
print('Compiling file', prog.infile)
# no longer needed, but may want to support assembly in future (?)
if assemblymode:
prog.restart_main_thread()
for i in xrange(INIT_REG_MAX):
for i in range(INIT_REG_MAX):
VARS['c%d'%i] = prog.curr_block.new_reg('c')
VARS['s%d'%i] = prog.curr_block.new_reg('s')
VARS['cg%d'%i] = prog.curr_block.new_reg('cg')
VARS['sg%d'%i] = prog.curr_block.new_reg('sg')
if i % 10000000 == 0 and i > 0:
print "Initialized %d register variables at" % i, time.asctime()
print("Initialized %d register variables at" % i, time.asctime())
# first pass determines how many assembler registers are used
prog.FIRST_PASS = True
execfile(prog.infile, VARS)
exec(compile(open(prog.infile).read(), prog.infile, 'exec'), VARS)
if instructions_base.Instruction.count != 0:
print 'instructions count', instructions_base.Instruction.count
print('instructions count', instructions_base.Instruction.count)
instructions_base.Instruction.count = 0
prog.FIRST_PASS = False
prog.reset_values()
# make compiler modules directly accessible
sys.path.insert(0, 'Compiler')
# create the tapes
execfile(prog.infile, VARS)
exec(compile(open(prog.infile).read(), prog.infile, 'exec'), VARS)
# optimize the tapes
for tape in prog.tapes:
@@ -66,14 +66,14 @@ def run(args, options, param=-1, merge_opens=True, emulate=True, \
sharedmem = list(prog.mem_s)
prog.emulate()
if prog.mem_c != clearmem or prog.mem_s != sharedmem:
print 'Warning: emulated memory values changed after compiler optimization'
print('Warning: emulated memory values changed after compiler optimization')
# raise CompilerError('Compiler optimization caused incorrect memory write.')
if prog.main_thread_running:
prog.update_req(prog.curr_tape)
print 'Program requires:', repr(prog.req_num)
print 'Cost:', 0 if prog.req_num is None else prog.req_num.cost()
print 'Memory size:', dict(prog.allocated_mem)
print('Program requires:', repr(prog.req_num))
print('Cost:', 0 if prog.req_num is None else prog.req_num.cost())
print('Memory size:', dict(prog.allocated_mem))
# finalize the memory
prog.finalize_memory()

View File

@@ -86,8 +86,8 @@ class HeapQ(object):
self.size = MemValue(int_type(0))
self.int_type = int_type
self.basic_type = basic_type
print 'heap: %d levels, depth %d, size %d, index size %d' % \
(self.levels, self.depth, self.heap.oram.size, self.value_index.size)
print('heap: %d levels, depth %d, size %d, index size %d' % \
(self.levels, self.depth, self.heap.oram.size, self.value_index.size))
def update(self, value, prio, for_real=True):
self._update(self.basic_type.hard_conv(value), \
self.basic_type.hard_conv(prio), \
@@ -217,7 +217,7 @@ class HeapQ(object):
def dijkstra(source, edges, e_index, oram_type, n_loops=None, int_type=sint):
basic_type = int_type.basic_type
vert_loops = n_loops * e_index.size / edges.size \
vert_loops = n_loops * e_index.size // edges.size \
if n_loops else -1
dist = oram_type(e_index.size, entry_size=(32,log2(e_index.size)), \
init_rounds=vert_loops, value_type=basic_type)
@@ -287,7 +287,7 @@ def test_dijkstra(G, source, oram_type=ORAM, n_loops=None, int_type=sint):
cint(i).print_reg('edge')
time()
edges[i] = edges_list[i]
vert_loops = n_loops * e_index.size / edges.size \
vert_loops = n_loops * e_index.size // edges.size \
if n_loops else e_index.size
for i in range(vert_loops):
cint(i).print_reg('vert')
@@ -307,7 +307,7 @@ def test_dijkstra_on_cycle(n, oram_type=ORAM, n_loops=None, int_type=sint):
time()
neighbour = ((i >> 1) + 2 * (i % 2) - 1 + n) % n
edges[i] = (neighbour, 1, i % 2)
vert_loops = n_loops * e_index.size / edges.size \
vert_loops = n_loops * e_index.size // edges.size \
if n_loops else e_index.size
@for_range(vert_loops)
def f(i):
@@ -390,14 +390,14 @@ class ExtInt(object):
class Vector(object):
""" Works like a vector. """
def __add__(self, other):
print 'add', type(self)
print('add', type(self))
res = type(self)(len(self))
@for_range(len(self))
def f(i):
res[i] = self[i] + other[i]
return res
def __sub__(self, other):
print 'sub', type(other)
print('sub', type(other))
res = type(other)(len(self))
@for_range(len(self))
def f(i):
@@ -412,7 +412,7 @@ class Vector(object):
res[0] += self[i] * other[i]
return res[0]
else:
print 'mul', type(self)
print('mul', type(self))
res = type(self)(len(self))
@for_range_parallel(1024, len(self))
def f(i):
@@ -477,7 +477,7 @@ def binarymin(A):
if len(A) == 1:
return [1], A[0]
else:
half = len(A) / 2
half = len(A) // 2
A_prime = VectorArray(half)
B = IntVectorArray(half)
i = IntVectorArray(len(A))

View File

@@ -1,9 +1,9 @@
from math import log, floor, ceil
from Compiler.instructions import *
import types
import comparison
import program
import util
from . import types
from . import comparison
from . import program
from . import util
##
## Helper functions for floating point arithmetic
@@ -16,7 +16,7 @@ def two_power(n):
else:
max = types.cint(1) << 31
res = 2**(n%31)
for i in range(n / 31):
for i in range(n // 31):
res *= max
return res
@@ -25,7 +25,7 @@ def shift_two(n, pos):
return n >> pos
else:
res = (n >> (pos%63))
for i in range(pos / 63):
for i in range(pos // 63):
res >>= 63
return res
@@ -139,7 +139,7 @@ def PreOpL(op, items):
kmax = 2**logk
output = list(items)
for i in range(logk):
for j in range(kmax/(2**(i+1))):
for j in range(kmax//(2**(i+1))):
y = two_power(i) + j*two_power(i+1) - 1
for z in range(1, 2**i+1):
if y+z < k:
@@ -153,7 +153,7 @@ def PreOpL2(op, items):
op must be a binary function that outputs a new register
"""
k = len(items)
half = k / 2
half = k // 2
output = list(items)
if k == 0:
return []
@@ -161,7 +161,7 @@ def PreOpL2(op, items):
v = PreOpL2(op, u)
for i in range(half):
output[2 * i + 1] = v[i]
for i in range(1, (k + 1) / 2):
for i in range(1, (k + 1) // 2):
output[2 * i] = op(v[i - 1], items[2 * i])
return output
@@ -185,8 +185,8 @@ def KOpL(op, a):
if k == 1:
return a[0]
else:
t1 = KOpL(op, a[:k/2])
t2 = KOpL(op, a[k/2:])
t1 = KOpL(op, a[:k//2])
t2 = KOpL(op, a[k//2:])
return op(t1, t2)
def KORL(a, kappa):
@@ -195,8 +195,8 @@ def KORL(a, kappa):
if k == 1:
return a[0]
else:
t1 = KORL(a[:k/2], kappa)
t2 = KORL(a[k/2:], kappa)
t1 = KORL(a[:k//2], kappa)
t2 = KORL(a[k//2:], kappa)
return t1 + t2 - t1*t2
def KORC(a, kappa):
@@ -234,7 +234,7 @@ def BitAdd(a, b, bits_to_compute=None):
bits s[0], ... , s[k] """
k = len(a)
if not bits_to_compute:
bits_to_compute = range(k)
bits_to_compute = list(range(k))
d = [None] * k
for i in range(1,k):
#assert(a[i].value == 0 or a[i].value == 1)
@@ -248,25 +248,25 @@ def BitAdd(a, b, bits_to_compute=None):
# (for testing)
def print_state():
print 'a: ',
print('a: ', end=' ')
for i in range(k):
print '%d ' % a[i].value,
print '\nb: ',
print('%d ' % a[i].value, end=' ')
print('\nb: ', end=' ')
for i in range(k):
print '%d ' % b[i].value,
print '\nd: ',
print('%d ' % b[i].value, end=' ')
print('\nd: ', end=' ')
for i in range(k):
print '%d ' % d[i][0].value,
print '\n ',
print('%d ' % d[i][0].value, end=' ')
print('\n ', end=' ')
for i in range(k):
print '%d ' % d[i][1].value,
print '\n\npg:',
print('%d ' % d[i][1].value, end=' ')
print('\n\npg:', end=' ')
for i in range(k):
print '%d ' % pg[i][0].value,
print '\n ',
print('%d ' % pg[i][0].value, end=' ')
print('\n ', end=' ')
for i in range(k):
print '%d ' % pg[i][1].value,
print ''
print('%d ' % pg[i][1].value, end=' ')
print('')
for bit in c:
pass#assert(bit.value == 0 or bit.value == 1)
@@ -281,7 +281,7 @@ def BitAdd(a, b, bits_to_compute=None):
try:
pass#assert(s[i].value == 0 or s[i].value == 1)
except AssertionError:
print '#assertion failed in BitAdd for s[%d]' % i
print('#assertion failed in BitAdd for s[%d]' % i)
print_state()
s[k] = c[k-1]
#print_state()
@@ -316,9 +316,9 @@ def BitDecField(a, k, m, kappa, bits_to_compute=None):
try:
pass#assert(c.value == (2**(k + kappa) + 2**k + (a.value%2**k) - rval) % comparison.program.P)
except AssertionError:
print 'BitDec assertion failed'
print 'a =', a.value
print 'a mod 2^%d =' % k, (a.value % 2**k)
print('BitDec assertion failed')
print('a =', a.value)
print('a mod 2^%d =' % k, (a.value % 2**k))
return types.intbitint.bit_adder(list(bits(c,m)), r)
@@ -503,7 +503,7 @@ def TruncPrRing(a, k, m, signed=True):
a += types.sint.get_random_bit() << i
return comparison.TruncLeakyInRing(a, k, m, signed=signed)
else:
from types import sint
from .types import sint
if signed:
a += (1 << (k - 1))
if program.Program.prog.use_trunc_pr:

View File

@@ -19,10 +19,10 @@ class SparseDiGraph(object):
if default_attributes is None:
default_attributes = { 'merges': None, 'stop': -1, 'start': -1 }
self.default_attributes = default_attributes
self.attribute_pos = dict(zip(default_attributes.keys(), range(len(default_attributes))))
self.attribute_pos = dict(list(zip(list(default_attributes.keys()), list(range(len(default_attributes))))))
self.n = max_nodes
# each node contains list of default attributes, followed by outoing edges
self.nodes = [self.default_attributes.values() for i in range(self.n)]
self.nodes = [list(self.default_attributes.values()) for i in range(self.n)]
self.succ = [set() for i in range(self.n)]
self.pred = [set() for i in range(self.n)]
self.weights = {}
@@ -45,7 +45,7 @@ class SparseDiGraph(object):
raise CompilerError('Cannot add node %d to graph of size %d' % (i, self.n))
node = self.nodes[i]
for a,value in attr.items():
for a,value in list(attr.items()):
if a in self.default_attributes:
node[self.attribute_pos[a]] = value
else:
@@ -72,7 +72,7 @@ class SparseDiGraph(object):
#del self.weights[(v,i)]
#self.nodes[v].remove(i)
self.pred[i] = []
self.nodes[i] = self.default_attributes.values()
self.nodes[i] = list(self.default_attributes.values())
def add_edge(self, i, j, weight=1):
if j not in self[i]:
@@ -111,7 +111,7 @@ def topological_sort(G, nbunch=None, pref=None):
return G[node]
else:
def get_children(node):
if pref.has_key(node):
if node in pref:
pref_set = set(pref[node])
for i in G[node]:
if i not in pref_set:
@@ -123,7 +123,7 @@ def topological_sort(G, nbunch=None, pref=None):
yield i
if nbunch is None:
nbunch = reversed(range(len(G)))
nbunch = reversed(list(range(len(G))))
for v in nbunch: # process all vertices in G
if v in explored:
continue
@@ -170,8 +170,8 @@ def reverse_dag_shortest_paths(G, source):
dist[source] = 0
for u in top_order:
if u ==68273:
print 'dist[68273]', dist[u]
print 'pred[u]', G.pred[u]
print('dist[68273]', dist[u])
print('pred[u]', G.pred[u])
if dist[u] is None:
continue
for v in G.pred[u]:
@@ -207,7 +207,7 @@ def longest_paths(G, sources=None):
G.weights[edge] = -G.weights[edge]
dist = {}
for source in sources:
print ('%s, ' % source),
print(('%s, ' % source), end=' ')
dist[source] = dag_shortest_paths(G, source)
#dist = johnson(G, sources)
# reset weights

View File

@@ -5,8 +5,8 @@ from Compiler import types
from Compiler.util import *
from oram import OptimalORAM,LinearORAM,RecursiveORAM,TrivialORAM,Entry
from library import for_range,do_while,time,start_timer,stop_timer,if_,print_ln,crash,print_str
from .oram import OptimalORAM,LinearORAM,RecursiveORAM,TrivialORAM,Entry
from .library import for_range,do_while,time,start_timer,stop_timer,if_,print_ln,crash,print_str
class OMatrixRow(object):
def __init__(self, oram, base, add_type):
@@ -27,7 +27,7 @@ class OMatrixRow(object):
class OMatrix:
def __init__(self, N, M=None, oram_type=OptimalORAM, int_type=types.sint):
print 'matrix', oram_type
print('matrix', oram_type)
self.N = N
self.M = M or N
self.oram = oram_type(N * self.M, entry_size=log2(N), init_rounds=0, \
@@ -73,7 +73,7 @@ class OReverseMatrix(OMatrix):
class OStack:
def __init__(self, N, oram_type=OptimalORAM, int_type=types.sint):
print 'stack', oram_type
print('stack', oram_type)
self.oram = oram_type(N, entry_size=log2(N), init_rounds=0, \
value_type=int_type.basic_type)
self.size = types.MemValue(int_type(0))
@@ -247,4 +247,4 @@ class Matchmaker:
self.reverse = reverse
self.int_type = int_type
self.basic_type = int_type.basic_type
print 'match', self.oram_type
print('match', self.oram_type)

View File

@@ -11,7 +11,7 @@ documentation
"""
import itertools
import tools
from . import tools
from random import randint
from Compiler.config import *
from Compiler.exceptions import *
@@ -51,7 +51,7 @@ class ldsi(base.Instruction):
@base.vectorize
class ldmc(base.DirectMemoryInstruction, base.ReadMemoryInstruction):
r""" Assigns register $c_i$ the value in memory \verb+C[n]+. """
__slots__ = ["code"]
__slots__ = []
code = base.opcodes['LDMC']
arg_format = ['cw','int']
@@ -62,7 +62,7 @@ class ldmc(base.DirectMemoryInstruction, base.ReadMemoryInstruction):
@base.vectorize
class ldms(base.DirectMemoryInstruction, base.ReadMemoryInstruction):
r""" Assigns register $s_i$ the value in memory \verb+S[n]+. """
__slots__ = ["code"]
__slots__ = []
code = base.opcodes['LDMS']
arg_format = ['sw','int']
@@ -73,7 +73,7 @@ class ldms(base.DirectMemoryInstruction, base.ReadMemoryInstruction):
@base.vectorize
class stmc(base.DirectMemoryWriteInstruction):
r""" Sets \verb+C[n]+ to be the value $c_i$. """
__slots__ = ["code"]
__slots__ = []
code = base.opcodes['STMC']
arg_format = ['c','int']
@@ -84,7 +84,7 @@ class stmc(base.DirectMemoryWriteInstruction):
@base.vectorize
class stms(base.DirectMemoryWriteInstruction):
r""" Sets \verb+S[n]+ to be the value $s_i$. """
__slots__ = ["code"]
__slots__ = []
code = base.opcodes['STMS']
arg_format = ['s','int']
@@ -94,7 +94,7 @@ class stms(base.DirectMemoryWriteInstruction):
@base.vectorize
class ldmint(base.DirectMemoryInstruction, base.ReadMemoryInstruction):
r""" Assigns register $ci_i$ the value in memory \verb+Ci[n]+. """
__slots__ = ["code"]
__slots__ = []
code = base.opcodes['LDMINT']
arg_format = ['ciw','int']
@@ -104,7 +104,7 @@ class ldmint(base.DirectMemoryInstruction, base.ReadMemoryInstruction):
@base.vectorize
class stmint(base.DirectMemoryWriteInstruction):
r""" Sets \verb+Ci[n]+ to be the value $ci_i$. """
__slots__ = ["code"]
__slots__ = []
code = base.opcodes['STMINT']
arg_format = ['ci','int']
@@ -227,7 +227,7 @@ class protectmemint(base.Instruction):
@base.vectorize
class movc(base.Instruction):
r""" Assigns register $c_i$ the value in the register $c_j$. """
__slots__ = ["code"]
__slots__ = []
code = base.opcodes['MOVC']
arg_format = ['cw','c']
@@ -238,7 +238,7 @@ class movc(base.Instruction):
@base.vectorize
class movs(base.Instruction):
r""" Assigns register $s_i$ the value in the register $s_j$. """
__slots__ = ["code"]
__slots__ = []
code = base.opcodes['MOVS']
arg_format = ['sw','s']
@@ -248,7 +248,7 @@ class movs(base.Instruction):
@base.vectorize
class movint(base.Instruction):
r""" Assigns register $ci_i$ the value in the register $ci_j$. """
__slots__ = ["code"]
__slots__ = []
code = base.opcodes['MOVINT']
arg_format = ['ciw','ci']
@@ -347,6 +347,21 @@ class use_prep(base.Instruction):
code = base.opcodes['USE_PREP']
arg_format = ['str','int']
class nplayers(base.Instruction):
r""" Number of players """
code = base.opcodes['NPLAYERS']
arg_format = ['ciw']
class threshold(base.Instruction):
r""" Maximal number of corrupt players """
code = base.opcodes['THRESHOLD']
arg_format = ['ciw']
class playerid(base.Instruction):
r""" My player number """
code = base.opcodes['PLAYERID']
arg_format = ['ciw']
###
### Basic arithmetic
###
@@ -738,7 +753,7 @@ class shrci(base.ClearShiftInstruction):
class triple(base.DataInstruction):
r""" Load secret variables $s_i$, $s_j$ and $s_k$
with the next multiplication triple. """
__slots__ = ['data_type']
__slots__ = []
code = base.opcodes['TRIPLE']
arg_format = ['sw','sw','sw']
data_type = 'triple'
@@ -752,7 +767,7 @@ class triple(base.DataInstruction):
class gbittriple(base.DataInstruction):
r""" Load secret variables $s_i$, $s_j$ and $s_k$
with the next GF(2) multiplication triple. """
__slots__ = ['data_type']
__slots__ = []
code = base.opcodes['GBITTRIPLE']
arg_format = ['sgw','sgw','sgw']
data_type = 'bittriple'
@@ -1400,7 +1415,8 @@ class convmodp(base.Instruction):
bitlength = program.bit_length if bitlength is None else bitlength
if bitlength > 64:
raise CompilerError('%d-bit conversion requested ' \
'but integer registers only have 64 bits')
'but integer registers only have 64 bits' % \
bitlength)
super(convmodp_class, self).__init__(*(args + (bitlength,)))
@base.vectorize
@@ -1433,7 +1449,7 @@ class muls(base.VarArgsInstruction, base.DataInstruction):
data_type = 'triple'
def get_repeat(self):
return len(self.args) / 3
return len(self.args) // 3
def merge_id(self):
# can merge different sizes
@@ -1508,6 +1524,8 @@ class dotprods(base.VarArgsInstruction, base.DataInstruction):
for j in range(self.args[i] - 2):
yield 's' + field
gf2n_arg_format = arg_format
def bases(self):
i = 0
while i < len(self.args):
@@ -1515,7 +1533,7 @@ class dotprods(base.VarArgsInstruction, base.DataInstruction):
i += self.args[i]
def get_repeat(self):
return sum(self.args[i] / 2 for i in self.bases()) * self.get_size()
return sum(self.args[i] // 2 for i in self.bases()) * self.get_size()
def get_def(self):
return [self.args[i + 1] for i in self.bases()]
@@ -1567,7 +1585,7 @@ class lts(base.CISC):
arg_format = ['sw', 's', 's', 'int', 'int']
def expand(self):
from types import sint
from .types import sint
a = sint()
subs(a, self.args[1], self.args[2])
comparison.LTZ(self.args[0], a, self.args[3], self.args[4])

View File

@@ -56,6 +56,9 @@ opcodes = dict(
USE_PREP = 0x1C,
STARTGRIND = 0x1D,
STOPGRIND = 0x1E,
NPLAYERS = 0xE2,
THRESHOLD = 0xE3,
PLAYERID = 0xE4,
# Addition
ADDC = 0x20,
ADDS = 0x21,
@@ -277,7 +280,7 @@ def gf2n(instruction):
vectorized GF_2^n instruction if a modp version exists. """
global_dict = inspect.getmodule(instruction).__dict__
if global_dict.has_key('v' + instruction.__name__):
if 'v' + instruction.__name__ in global_dict:
vectorized = True
else:
vectorized = False
@@ -316,7 +319,7 @@ def gf2n(instruction):
if 'gf2n_arg_format' in instruction_cls.__dict__:
arg_format = instruction_cls.gf2n_arg_format
elif isinstance(instruction_cls.arg_format, itertools.repeat):
__f = instruction_cls.arg_format.next()
__f = next(instruction_cls.arg_format)
if __f != 'int' and __f != 'p':
arg_format = itertools.repeat(__f[0] + 'g' + __f[1:])
else:
@@ -420,7 +423,7 @@ class ClearIntAF(RegisterArgFormat):
class IntArgFormat(ArgFormat):
@classmethod
def check(cls, arg):
if not isinstance(arg, (int, long)):
if not isinstance(arg, int):
raise ArgumentError(arg, 'Expected an integer-valued argument')
@classmethod
@@ -512,7 +515,7 @@ class Instruction(object):
Instruction.count += 1
if Instruction.count % 100000 == 0:
print "Compiled %d lines at" % self.__class__.count, time.asctime()
print("Compiled %d lines at" % self.__class__.count, time.asctime())
def get_code(self):
return self.code
@@ -535,7 +538,7 @@ class Instruction(object):
def check_args(self):
""" Check the args match up with that specified in arg_format """
for n,(arg,f) in enumerate(itertools.izip_longest(self.args, self.arg_format)):
for n,(arg,f) in enumerate(itertools.zip_longest(self.args, self.arg_format)):
if arg is None:
if not isinstance(self.arg_format, (list, tuple)):
break # end of optional arguments
@@ -758,8 +761,6 @@ class ClearShiftInstruction(ClearImmediate):
else:
# assume 64-bit machine
bits = 63
elif program.options.ring:
bits = int(program.options.ring) - 1
if self.args[2] > bits:
raise CompilerError('Shifting by more than %d bits '
'not implemented' % bits)

View File

@@ -1,4 +1,4 @@
from Compiler.types import cint,sint,cfix,sfix,sfloat,MPCThread,Array,MemValue,cgf2n,sgf2n,_number,_mem,_register,regint,Matrix,_types, cfloat, _single
from Compiler.types import cint,sint,cfix,sfix,sfloat,MPCThread,Array,MemValue,cgf2n,sgf2n,_number,_mem,_register,regint,Matrix,_types, cfloat, _single, localint
from Compiler.instructions import *
from Compiler.util import tuplify,untuplify
from Compiler import instructions,instructions_base,comparison,program,util
@@ -6,6 +6,7 @@ import inspect,math
import random
import collections
import operator
from functools import reduce
def get_program():
return instructions.program
@@ -101,6 +102,8 @@ def print_ln_if(cond, ss, *args):
else:
subs = ss.split('%s')
assert len(subs) == len(args) + 1
if isinstance(cond, localint):
cond = cond._v
cond = cint.conv(cond)
for i, s in enumerate(subs):
if i != 0:
@@ -274,7 +277,7 @@ class FunctionTapeCall:
def join(self):
self.thread.join()
instructions.program.free(self.base, 'ci')
for reg_type,addr in self.bases.iteritems():
for reg_type,addr in self.bases.items():
get_program().free(addr, reg_type.reg_type)
class Function:
@@ -287,7 +290,7 @@ class Function:
self.compile_args = compile_args
def __call__(self, *args):
args = tuple(arg.read() if isinstance(arg, MemValue) else arg for arg in args)
get_reg_type = lambda x: regint if isinstance(x, (int, long)) else type(x)
get_reg_type = lambda x: regint if isinstance(x, int) else type(x)
if len(args) not in self.type_args:
# first call
type_args = collections.defaultdict(list)
@@ -296,9 +299,11 @@ class Function:
def wrapped_function(*compile_args):
base = get_arg()
bases = dict((t, regint.load_mem(base + i)) \
for i,t in enumerate(sorted(type_args)))
for i,t in enumerate(sorted(type_args,
key=lambda x:
x.reg_type)))
runtime_args = [None] * len(args)
for t in sorted(type_args):
for t in sorted(type_args, key=lambda x: x.reg_type):
for i,i_arg in enumerate(type_args[t]):
runtime_args[i_arg] = t.load_mem(bases[t] + i)
return self.function(*(list(compile_args) + runtime_args))
@@ -308,7 +313,8 @@ class Function:
base = instructions.program.malloc(len(type_args), 'ci')
bases = dict((t, get_program().malloc(len(type_args[t]), t)) \
for t in type_args)
for i,reg_type in enumerate(sorted(type_args)):
for i,reg_type in enumerate(sorted(type_args,
key=lambda x: x.reg_type)):
store_in_mem(bases[reg_type], base + i)
for j,i_arg in enumerate(type_args[reg_type]):
if get_reg_type(args[i_arg]) != reg_type:
@@ -353,13 +359,13 @@ class FunctionBlock(Function):
block.alloc_pool = defaultdict(set)
del parent_node.children[-1]
self.node = get_tape().req_node
print 'Compiling function', self.name
print('Compiling function', self.name)
result = wrapped_function(*self.compile_args)
if result is not None:
self.result = memorize(result)
else:
self.result = None
print 'Done compiling function', self.name
print('Done compiling function', self.name)
p_return_address = get_tape().program.malloc(1, 'ci')
get_tape().function_basicblocks[block] = p_return_address
return_address = regint.load_mem(p_return_address)
@@ -429,7 +435,7 @@ def sort(a):
res = a
for i in range(len(a)):
for j in reversed(range(i)):
for j in reversed(list(range(i))):
res[j], res[j+1] = cond_swap(res[j], res[j+1])
return res
@@ -443,7 +449,7 @@ def odd_even_merge(a):
odd_even_merge(even)
odd_even_merge(odd)
a[0] = even[0]
for i in range(1, len(a) / 2):
for i in range(1, len(a) // 2):
a[2*i-1], a[2*i] = cond_swap(odd[i-1], even[i])
a[-1] = odd[-1]
@@ -451,8 +457,8 @@ def odd_even_merge_sort(a):
if len(a) == 1:
return
elif len(a) % 2 == 0:
lower = a[:len(a)/2]
upper = a[len(a)/2:]
lower = a[:len(a)//2]
upper = a[len(a)//2:]
odd_even_merge_sort(lower)
odd_even_merge_sort(upper)
a[:] = lower + upper
@@ -472,10 +478,10 @@ def chunky_odd_even_merge_sort(a):
def round():
for i in range(len(a)):
a[i] = type(a[i]).load_mem(i * a[i].sizeof())
for i in range(len(a) / l):
for j in range(l / k):
for i in range(len(a) // l):
for j in range(l // k):
base = i * l + j
step = l / k
step = l // k
if k == 2:
a[base], a[base+step] = cond_swap(a[base], a[base+step])
else:
@@ -514,7 +520,7 @@ def chunkier_odd_even_merge_sort(a, n=None, max_chunk_size=512, n_threads=7, use
def run_chunk(size, base):
if size not in chunks:
def swap_list(list_base):
for i in range(size / 2):
for i in range(size // 2):
base = list_base + 2 * i
x, y = cond_swap(load_secret_mem(base),
load_secret_mem(base + 1))
@@ -526,8 +532,8 @@ def chunkier_odd_even_merge_sort(a, n=None, max_chunk_size=512, n_threads=7, use
def run_round(size):
# minimize number of chunk sizes
n_chunks = int(math.ceil(1.0 * size / max_chunk_size))
lower_size = size / n_chunks / 2 * 2
n_lower_size = n_chunks - (size - n_chunks * lower_size) / 2
lower_size = size // n_chunks // 2 * 2
n_lower_size = n_chunks - (size - n_chunks * lower_size) // 2
# print len(to_swap) == lower_size * n_lower_size + \
# (lower_size + 2) * (n_chunks - n_lower_size), \
# len(to_swap), n_chunks, lower_size, n_lower_size
@@ -603,10 +609,10 @@ def chunkier_odd_even_merge_sort(a, n=None, max_chunk_size=512, n_threads=7, use
k *= 2
size = 0
instructions.program.curr_tape.merge_opens = False
for i in range(n / l):
for j in range(l / k):
for i in range(n // l):
for j in range(l // k):
base = i * l + j
step = l / k
step = l // k
size += run_setup(k, a_base + base, step, tmp_base + size)
run_threads_in_rounds(pre_threads)
run_round(size)
@@ -651,7 +657,7 @@ def loopy_chunkier_odd_even_merge_sort(a, n=None, max_chunk_size=512, n_threads=
def run_chunk(size, base):
if size not in chunks:
def swap_list(list_base):
for i in range(size / 2):
for i in range(size // 2):
base = list_base + 2 * i
x, y = cond_swap(load_secret_mem(base),
load_secret_mem(base + 1))
@@ -663,8 +669,8 @@ def loopy_chunkier_odd_even_merge_sort(a, n=None, max_chunk_size=512, n_threads=
def run_round(size):
# minimize number of chunk sizes
n_chunks = int(math.ceil(1.0 * size / max_chunk_size))
lower_size = size / n_chunks / 2 * 2
n_lower_size = n_chunks - (size - n_chunks * lower_size) / 2
lower_size = size // n_chunks // 2 * 2
n_lower_size = n_chunks - (size - n_chunks * lower_size) // 2
# print len(to_swap) == lower_size * n_lower_size + \
# (lower_size + 2) * (n_chunks - n_lower_size), \
# len(to_swap), n_chunks, lower_size, n_lower_size
@@ -692,7 +698,7 @@ def loopy_chunkier_odd_even_merge_sort(a, n=None, max_chunk_size=512, n_threads=
def outer(i):
def inner(j):
base = j
step = l / k
step = l // k
if k == 2:
tmp_addr = regint.load_mem(tmp_i)
load_and_store(base, tmp_addr)
@@ -704,19 +710,19 @@ def loopy_chunkier_odd_even_merge_sort(a, n=None, max_chunk_size=512, n_threads=
load_and_store(m, tmp_addr)
store_in_mem(tmp_addr + 1, tmp_i)
range_loop(inner2, base + step, base + (k - 1) * step, step)
range_loop(inner, a_base + i * l, a_base + i * l + l / k)
range_loop(inner, a_base + i * l, a_base + i * l + l // k)
instructions.program.curr_tape.merge_opens = False
to_tmp = True
store_in_mem(tmp_base, tmp_i)
range_loop(outer, n / l)
range_loop(outer, n // l)
if k == 2:
run_round(n)
else:
run_round(n / k * (k - 2))
run_round(n // k * (k - 2))
instructions.program.curr_tape.merge_opens = False
to_tmp = False
store_in_mem(tmp_base, tmp_i)
range_loop(outer, n / l)
range_loop(outer, n // l)
if isinstance(a, list):
instructions.program.restart_main_thread()
@@ -734,15 +740,15 @@ def loopy_odd_even_merge_sort(a, sorted_length=1, n_parallel=32):
k = 1
while k < l:
k *= 2
n_outer = len(a) / l
n_inner = l / k
n_innermost = 1 if k == 2 else k / 2 - 1
@for_range_parallel(n_parallel / n_innermost / n_inner, n_outer)
n_outer = len(a) // l
n_inner = l // k
n_innermost = 1 if k == 2 else k // 2 - 1
@for_range_parallel(n_parallel // n_innermost // n_inner, n_outer)
def loop(i):
@for_range_parallel(n_parallel / n_innermost, n_inner)
@for_range_parallel(n_parallel // n_innermost, n_inner)
def inner(j):
base = i*l + j
step = l/k
step = l//k
if k == 2:
a[base], a[base+step] = cond_swap(a[base], a[base+step])
else:
@@ -805,7 +811,7 @@ def range_loop(loop_body, start, stop=None, step=None):
# known loop count
if condition(start):
get_tape().req_node.children[-1].aggregator = \
lambda x: ((stop - start) / step) * x[0]
lambda x: ((stop - start) // step) * x[0]
def for_range(start, stop=None, step=None):
""" Execute loop bodies consecutively """
@@ -840,7 +846,7 @@ def map_reduce_single(n_parallel, n_loops, initializer=lambda *x: [],
my_n_parallel = n_parallel
if isinstance(n_parallel, int):
if isinstance(n_loops, int):
loop_rounds = n_loops / n_parallel \
loop_rounds = n_loops // n_parallel \
if n_parallel < n_loops else 0
else:
loop_rounds = n_loops / n_parallel
@@ -884,7 +890,7 @@ def map_reduce_single(n_parallel, n_loops, initializer=lambda *x: [],
regint.push(k)
return i + k
my_n_parallel = n_opt_loops
loop_rounds = n_loops / my_n_parallel
loop_rounds = n_loops // my_n_parallel
blocks = get_tape().basicblocks
n_to_merge = 5
if loop_rounds == 1 and parent_block is blocks[-n_to_merge]:
@@ -966,7 +972,7 @@ def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \
indices = []
for n in reversed(split):
indices.insert(0, i % n)
i /= n
i //= n
return loop_body(*indices)
return new_body
new_dec = map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, thread_mem_req)
@@ -979,7 +985,7 @@ def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \
else:
return dec
def decorator(loop_body):
thread_rounds = n_loops / n_threads
thread_rounds = n_loops // n_threads
remainder = n_loops % n_threads
for t in thread_mem_req:
if t != regint:
@@ -1233,10 +1239,25 @@ def stop_timer(timer_id=0):
stop(timer_id)
get_tape().start_new_basicblock(name='post-stop-timer')
def get_number_of_players():
res = regint()
nplayers(res)
return res
def get_threshold():
res = regint()
threshold(res)
return res
def get_player_id():
res = localint()
playerid(res._v)
return res
# Fixed point ops
from math import ceil, log
from floatingpoint import PreOR, TruncPr, two_power, shift_two
from .floatingpoint import PreOR, TruncPr, two_power, shift_two
def approximate_reciprocal(divisor, k, f, theta):
"""
@@ -1369,7 +1390,7 @@ def FPDiv(a, b, k, f, kappa, simplex_flag=False, nearest=False):
# no probabilistic truncation in binary circuits
nearest = True
res_f = f
f = max((k - nearest) / 2 + 1, f)
f = max((k - nearest) // 2 + 1, f)
assert 2 * f > k - nearest
theta = int(ceil(log(k/3.5) / log(2)))
alpha = b.get_type(2 * k).two_power(2*f)
@@ -1387,7 +1408,7 @@ def FPDiv(a, b, k, f, kappa, simplex_flag=False, nearest=False):
x = x.round(2*k, 2*f, kappa, nearest, signed=True)
y = y.extend(2 * k) * (alpha + x).extend(2 * k)
y = y.round(k + 2 * f, 3 * f - res_f, kappa, nearest, signed=True)
y = y.round(k + 3 * f - res_f, 3 * f - res_f, kappa, nearest, signed=True)
return y
def AppRcr(b, k, f, kappa, simplex_flag=False, nearest=False):
"""

View File

@@ -3,6 +3,7 @@ import mpc_math, math
from Compiler.types import *
from Compiler.types import _unreduced_squant
from Compiler.library import *
from functools import reduce
def log_e(x):
return mpc_math.log_fx(x, math.e)
@@ -129,7 +130,7 @@ class DenseBase(Layer):
tmp[j][k] = sfix.unreduced_dot_product(a, b)
if self.d_in * self.d_out < 100000:
print 'reduce at once'
print('reduce at once')
@multithread(self.n_threads, self.d_in * self.d_out)
def _(base, size):
self.nabla_W.assign_vector(
@@ -386,7 +387,7 @@ class QuantConvBase(QuantBase):
s.set_params(sfloat.get_input_from(player), sint.get_input_from(player))
self.weights.input_from(player, budget=100000)
self.bias.input_from(player)
print 'WARNING: assuming that bias quantization parameters are correct'
print('WARNING: assuming that bias quantization parameters are correct')
self.output_squant.params.precompute(self.input_squant.params, self.weight_squant.params)
@@ -404,7 +405,7 @@ class QuantConvBase(QuantBase):
start_timer(2)
n_outputs = reduce(operator.mul, self.output_shape)
if n_outputs % self.n_threads == 0:
n_per_thread = n_outputs / self.n_threads
n_per_thread = n_outputs // self.n_threads
@for_range_opt_multithread(self.n_threads, self.n_threads)
def _(i):
res = _unreduced_squant(
@@ -556,7 +557,7 @@ class QuantAveragePool2d(QuantBase):
self.filter_size = filter_size
def input_from(self, player):
print 'WARNING: assuming that input and output quantization parameters are the same'
print('WARNING: assuming that input and output quantization parameters are the same')
for s in self.input_squant, self.output_squant:
s.set_params(sfloat.get_input_from(player), sint.get_input_from(player))
@@ -567,7 +568,7 @@ class QuantAveragePool2d(QuantBase):
_, output_h, output_w, n_channels_out = self.output_shape
n = input_h * input_w
print 'divisor: ', n
print('divisor: ', n)
assert output_h == output_w == 1
assert n_channels_in == n_channels_out
@@ -599,7 +600,7 @@ class QuantAveragePool2d(QuantBase):
acc += self.X[0][in_y][in_x][c].v
#fc += 1
logn = int(math.log(n, 2))
acc = (acc + n / 2)
acc = (acc + n // 2)
if 2 ** logn == n:
acc = acc.round(self.output_squant.params.k + logn,
logn, nearest=True)
@@ -614,7 +615,7 @@ class QuantReshape(QuantBase):
super(QuantReshape, self).__init__(input_shape, output_shape)
def input_from(self, player):
print 'WARNING: assuming that input and output quantization parameters are the same'
print('WARNING: assuming that input and output quantization parameters are the same')
_ = self.new_squant()
for s in self.input_squant, _, self.output_squant:
s.set_params(sfloat.get_input_from(player), sint.get_input_from(player))
@@ -628,7 +629,7 @@ class QuantReshape(QuantBase):
class QuantSoftmax(QuantBase):
def input_from(self, player):
print 'WARNING: assuming that input and output quantization parameters are the same'
print('WARNING: assuming that input and output quantization parameters are the same')
for s in self.input_squant, self.output_squant:
s.set_params(sfloat.get_input_from(player), sint.get_input_from(player))
@@ -666,14 +667,14 @@ class Optimizer:
N = self.layers[0].N
assert self.layers[-1].N == N
assert N % 2 == 0
n = N / 2
n = N // 2
@for_range(n)
def _(i):
self.layers[-1].Y[i] = 0
self.layers[-1].Y[i + n] = 1
n_per_epoch = int(math.ceil(1. * max(len(X) for X in
self.X_by_label) / n))
print '%d runs per epoch' % n_per_epoch
print('%d runs per epoch' % n_per_epoch)
indices_by_label = []
for label, X in enumerate(self.X_by_label):
indices = regint.Array(n * n_per_epoch)
@@ -794,8 +795,8 @@ class SGD(Optimizer):
x = x.reveal()
print_ln_if((x > 1000) + (x < -1000),
name + ': %s %s %s %s',
*[y.v.reveal() for y in old, red_old, \
new, diff])
*[y.v.reveal() for y in (old, red_old, \
new, diff)])
if self.debug:
d = delta_theta.get_vector().reveal()
a = cfix.Array(len(d.v))

View File

@@ -474,7 +474,7 @@ def norm_simplified_SQ(b, k):
m_odd = m_odd + z[i]
# construct w,
k_over_2 = k / 2 + 1
k_over_2 = k // 2 + 1
w_array = [0] * (k_over_2)
w_array[0] = z[0]
for i in range(1, k_over_2):
@@ -510,7 +510,7 @@ def sqrt_simplified_fx(x):
m_odd = (1 - 2 * m_odd) + m_odd
w = (w * 2 - w) * (1-m_odd) + w
# map number to use sfix format and instantiate the number
w = types.sfix(w * 2 ** ((x.f - (x.f % 2)) / 2))
w = types.sfix(w * 2 ** ((x.f - (x.f % 2)) // 2))
# obtains correct 2 ** (m/2)
w = (w * (types.cfix(2 ** (1/2.0))) - w) * m_odd + w
# produce x/ 2^(m/2)

View File

@@ -4,6 +4,7 @@ import collections
import itertools
import operator
import sys
from functools import reduce
from Compiler.types import *
from Compiler.types import _secret
@@ -95,7 +96,7 @@ class gf2nBlock(Block):
prod_bits = [start * bit for bit in value_bits]
anti_bits = [v - p for v,p in zip(value_bits,prod_bits)]
self.lower = sum(bit << i for i,bit in enumerate(prod_bits[:length]))
self.bits = map(operator.add, anti_bits[:length], prod_bits[length:]) + \
self.bits = list(map(operator.add, anti_bits[:length], prod_bits[length:])) + \
anti_bits[length:]
self.adjust = if_else(start, 1 << length, cgf2n(1))
elif entries_per_block < 4:
@@ -105,7 +106,7 @@ class gf2nBlock(Block):
choice_bits = demux(start_bits)
inv_bits = [1 - bit for bit in floatingpoint.PreOR(choice_bits, None)]
mask_bits = sum(([x] * length for x in inv_bits), [])
lower_bits = map(operator.mul, value_bits, mask_bits)
lower_bits = list(map(operator.mul, value_bits, mask_bits))
self.lower = sum(bit << i for i,bit in enumerate(lower_bits))
self.bits = [sum(map(operator.mul, choice_bits, value_bits[i::length])) \
for i in range(length)]
@@ -124,7 +125,7 @@ class gf2nBlock(Block):
pre_bits = floatingpoint.PreOpL(lambda x,y,z=None: x + y, bits)
inv_bits = [1 - bit for bit in pre_bits]
mask_bits = sum(([x] * length for x in inv_bits), [])
lower_bits = map(operator.mul, value_bits, mask_bits)
lower_bits = list(map(operator.mul, value_bits, mask_bits))
masked = self.value - sum(bit << i for i,bit in enumerate(lower_bits))
self.lower = sum(bit << i for i,bit in enumerate(lower_bits))
self.bits = (masked / adjust).bit_decompose(used_bits)
@@ -177,12 +178,12 @@ def demux_list(x):
return [1]
elif n == 1:
return [1 - x[0], x[0]]
a = demux_list(x[:n/2])
b = demux_list(x[n/2:])
a = demux_list(x[:n//2])
b = demux_list(x[n//2:])
n_a = len(a)
a *= len(b)
b = reduce(operator.add, ([i] * n_a for i in b))
res = map(operator.mul, a, b)
res = list(map(operator.mul, a, b))
return res
def demux_array(x, res=None):
@@ -193,12 +194,12 @@ def demux_array(x, res=None):
res[0] = 1 - x[0]
res[1] = x[0]
else:
a = Array(2**(n/2), type(x[0]))
a.assign(demux(x[:n/2]))
b = Array(2**(n-n/2), type(x[0]))
b.assign(demux(x[n/2:]))
a = Array(2**(n//2), type(x[0]))
a.assign(demux(x[:n//2]))
b = Array(2**(n-n//2), type(x[0]))
b.assign(demux(x[n//2:]))
@for_range_multithread(get_n_threads(len(res)), \
max(1, n_parallel / len(b)), len(a))
max(1, n_parallel // len(b)), len(a))
def f(i):
@for_range_parallel(n_parallel, len(b))
def f(j):
@@ -234,7 +235,7 @@ class Value(object):
return Value(other * self.value, other * self.empty)
__rmul__ = __mul__
def equal(self, other, length=None):
if isinstance(other, (int, long)) and isinstance(self.value, (int, long)):
if isinstance(other, int) and isinstance(self.value, int):
return (1 - self.empty) * (other == self.value)
return (1 - self.empty) * self.value.equal(other, length)
def reveal(self):
@@ -252,9 +253,9 @@ class Value(object):
try:
value = self.empty
while True:
if value in (1, 1L):
if value == 1:
return '<>'
if value in (0, 0L):
if value == 0:
return '<%s>' % str(self.value)
value = value.value
except:
@@ -297,8 +298,8 @@ class Entry(object):
self.created_non_empty = False
if x is None:
v = iter(v)
self.is_empty = v.next()
self.v = v.next()
self.is_empty = next(v)
self.v = next(v)
self.x = ValueTuple(v)
else:
if empty is None:
@@ -332,7 +333,7 @@ class Entry(object):
try:
return Entry(i + j for i,j in zip(self, other))
except:
print self, other
print(self, other)
raise
def __sub__(self, other):
return Entry(i - j for i,j in zip(self, other))
@@ -342,7 +343,7 @@ class Entry(object):
try:
return Entry(other * i for i in self)
except:
print self, other
print(self, other)
raise
__rmul__ = __mul__
def reveal(self):
@@ -372,8 +373,8 @@ class RefRAM(object):
for t,array in zip(self.entry_type,oram.ram.l)]
self.index = index
def init_mem(self, empty_entry):
print 'init ram'
for a,value in zip(self.l, empty_entry.defaults.values()):
print('init ram')
for a,value in zip(self.l, list(empty_entry.defaults.values())):
# don't use threads if n_threads explicitly set to 1
a.assign_all(value, n_threads != 1, conv=False)
def get_empty_bits(self):
@@ -392,14 +393,14 @@ class RefRAM(object):
return [Value(self.l[2+index][i], self.l[0][i]) for i in range(self.size)]
def __getitem__(self, index):
if print_access:
print 'get', id(self), index
print('get', id(self), index)
return Entry(a[index] for a in self.l)
def __setitem__(self, index, value):
if print_access:
print 'set', id(self), index
print('set', id(self), index)
if not isinstance(value, Entry):
raise Exception('entries only please: %s' % str(value))
for i,(a,v) in enumerate(zip(self.l, value.values())):
for i,(a,v) in enumerate(zip(self.l, list(value.values()))):
a[index] = v
def __len__(self):
return self.size
@@ -524,7 +525,7 @@ class RefTrivialORAM(EndRecursiveEviction):
self.value_type, self.entry_size = oram.internal_entry_size()
self.size = oram.bucket_size
def init_mem(self):
print 'init trivial oram'
print('init trivial oram')
self.ram.init_mem(self.empty_entry(apply_type=False))
def search(self, read_index):
if use_binary_search and self.value_type == sgf2n:
@@ -554,7 +555,7 @@ class RefTrivialORAM(EndRecursiveEviction):
self.last_index = read_index
found, empty = self.search(read_index)
entries = [entry for entry in self.ram]
prod_entries = map(operator.mul, found, entries)
prod_entries = list(map(operator.mul, found, entries))
read_value = sum((entry.x.skip(skip) for entry in prod_entries), \
empty * empty_entry.x.skip(skip))
for i,(entry, prod_entry) in enumerate(zip(entries, prod_entries)):
@@ -566,7 +567,7 @@ class RefTrivialORAM(EndRecursiveEviction):
def read_and_remove_by_public(self, index):
empty_entry = self.empty_entry(False)
entries = [entry for entry in self.ram]
prod_entries = map(operator.mul, index, entries)
prod_entries = list(map(operator.mul, index, entries))
read_entry = reduce(operator.add, prod_entries)
for i,(entry, prod_entry) in enumerate(zip(entries, prod_entries)):
self.ram[i] = entry - prod_entry + index[i] * empty_entry
@@ -574,7 +575,7 @@ class RefTrivialORAM(EndRecursiveEviction):
@method_block
def _read(self, index):
found, empty = self.search(index)
read_value = sum(map(operator.mul, found, self.ram.get_values()), \
read_value = sum(list(map(operator.mul, found, self.ram.get_values())), \
empty * self.empty_entry(False).x)
return read_value, empty
@method_block
@@ -583,8 +584,8 @@ class RefTrivialORAM(EndRecursiveEviction):
found, not_found = self.search(index)
add_here = self.find_first_empty()
entries = [entry for entry in self.ram]
prod_values = map(operator.mul, found, \
(entry.x for entry in entries))
prod_values = list(map(operator.mul, found, \
(entry.x for entry in entries)))
read_value = sum(prod_values, not_found * empty_entry.x)
new_value = ValueTuple(new_value) \
if isinstance(new_value, (tuple, list)) \
@@ -699,15 +700,15 @@ class RefTrivialORAM(EndRecursiveEviction):
for k in range(2**(j)):
t = k + 2**(j) - 1
if k % 2 == 0:
M += bit_prods[(t-1)/2] * mult_tree[t]
M += bit_prods[(t-1)//2] * mult_tree[t]
b = 1 - M.equal(0, 40, expand)
for k in range(2**j):
t = k + 2**j - 1
if k % 2 == 0:
v = bit_prods[(t-1)/2] * b
bit_prods[t] = bit_prods[(t-1)/2] - v
v = bit_prods[(t-1)//2] * b
bit_prods[t] = bit_prods[(t-1)//2] - v
else:
bit_prods[t] = v
return bit_prods[n-1:n-1+self.size], 1 - bit_prods[0]
@@ -734,7 +735,7 @@ class RefTrivialORAM(EndRecursiveEviction):
print_ln('Bucket overflow')
crash()
if debug and not sum(add_here) and not new_entry.empty():
print self.empty_entry()
print(self.empty_entry())
raise Exception('no space for %s in %s' % (str(new_entry), str(self)))
self.check(new_entry=new_entry, op='add')
def pop(self):
@@ -746,7 +747,7 @@ class RefTrivialORAM(EndRecursiveEviction):
pop_here = [prefix_empty[i+1] - prefix_empty[i] \
for i in range(len(self.ram))]
entries = [entry for entry in self.ram]
prod_entries = map(operator.mul, pop_here, self.ram)
prod_entries = list(map(operator.mul, pop_here, self.ram))
result = (1 - sum(pop_here)) * empty_entry
result = sum(prod_entries, result)
for i,(entry, prod_entry) in enumerate(zip(entries, prod_entries)):
@@ -980,7 +981,7 @@ class LocalIndexStructure(List):
@for_range(init_rounds if init_rounds > 0 else size)
def f(i):
self.l[0][i] = random_block(entry_size, value_type)
print 'index size:', size
print('index size:', size)
def update(self, index, value, evict=None):
read_value = self[index]
#print 'read', index, read_value
@@ -1005,7 +1006,7 @@ class TreeORAM(AbstractORAM):
""" Tree ORAM. """
def __init__(self, size, value_type=sint, value_length=1, entry_size=None, \
bucket_oram=TrivialORAM, init_rounds=-1):
print 'create oram of size', size
print('create oram of size', size)
self.bucket_oram = bucket_oram
# heuristic bucket size
delta = 3
@@ -1013,9 +1014,9 @@ class TreeORAM(AbstractORAM):
# size + 1 for bucket overflow check
self.bucket_size = min(int(math.ceil((1 + delta) * k)), size + 1)
self.D = log2(max(size / k, 2))
print 'bucket size:', self.bucket_size
print 'depth:', self.D
print 'complexity:', self.bucket_size * (self.D + 1)
print('bucket size:', self.bucket_size)
print('depth:', self.D)
print('complexity:', self.bucket_size * (self.D + 1))
self.value_type = value_type
if entry_size is not None:
self.value_length = len(tuplify(entry_size))
@@ -1279,8 +1280,8 @@ class TreeORAM(AbstractORAM):
# split into 2 if bucket size can't fit into one field elem
if self.bucket_size + Program.prog.security > 128:
parity = (empty_positions[i]+1) % 2
half = (empty_positions[i]+1 - parity) / 2
half_max = self.bucket_size / 2
half = (empty_positions[i]+1 - parity) // 2
half_max = self.bucket_size // 2
bits = floatingpoint.B2U(half, half_max, Program.prog.security)[0]
bits2 = floatingpoint.B2U(half+parity, half_max, Program.prog.security)[0]
@@ -1384,11 +1385,11 @@ def get_parallel(index_size, value_type, value_length):
value_size = get_value_size(value_type)
if value_type == sint:
value_size *= 2
res = max(1, min(50 * 32 / (value_length * value_size), \
800 * 32 / (value_length * index_size)))
res = max(1, min(50 * 32 // (value_length * value_size), \
800 * 32 // (value_length * index_size)))
if comparison.const_rounds:
res = max(1, res / 2)
print 'Reading %d buckets in parallel' % res
res = max(1, res // 2)
print('Reading %d buckets in parallel' % res)
return res
class PackedIndexStructure(object):
@@ -1403,7 +1404,7 @@ class PackedIndexStructure(object):
self.value_type = value_type
for demux_bits in range(max_demux_bits + 1):
self.log_entries_per_element = min(log2(size), \
int(math.floor(math.log(float(get_value_size(value_type)) / \
int(math.floor(math.log(float(get_value_size(value_type)) // \
sum(self.entry_size), 2))))
self.log_elements_per_block = \
max(0, min(demux_bits, log2(size) - \
@@ -1423,24 +1424,24 @@ class PackedIndexStructure(object):
self.elements_per_entry = len(self.split_sizes)
self.log_elements_per_block = log2(self.elements_per_entry)
self.log_entries_per_element = -self.log_elements_per_block
print 'split sizes:', self.split_sizes
print('split sizes:', self.split_sizes)
self.log_entries_per_block = \
self.log_elements_per_block + self.log_entries_per_element
self.elements_per_block = 2**self.log_elements_per_block
self.entries_per_element = 2**self.log_entries_per_element
self.entries_per_block = 2**self.log_entries_per_block
self.used_bits = self.entries_per_element * sum(self.entry_size)
real_size = -(-size / self.entries_per_block)
print 'packed size:', real_size
print 'index size:', size
print 'entry size:', self.entry_size
print 'log(entries per element):', self.log_entries_per_element
print 'entries per element:', self.entries_per_element
print 'log(entries per block):', self.log_entries_per_block
print 'entries per block:', self.entries_per_block
print 'log(elements per block):', self.log_elements_per_block
print 'elements per block:', self.elements_per_block
print 'used bits:', self.used_bits
real_size = -(-size // self.entries_per_block)
print('packed size:', real_size)
print('index size:', size)
print('entry size:', self.entry_size)
print('log(entries per element):', self.log_entries_per_element)
print('entries per element:', self.entries_per_element)
print('log(entries per block):', self.log_entries_per_block)
print('entries per block:', self.entries_per_block)
print('log(elements per block):', self.log_elements_per_block)
print('elements per block:', self.elements_per_block)
print('used bits:', self.used_bits)
entry_size = [self.used_bits] * self.elements_per_block
if real_size > 1:
# no need to init underlying ORAM, will be initialized implicitely
@@ -1454,10 +1455,10 @@ class PackedIndexStructure(object):
self.index_type = self.l.index_type
if init_rounds:
if init_rounds > 0:
real_init_rounds = init_rounds * real_size / size
real_init_rounds = init_rounds * real_size // size
else:
real_init_rounds = real_size
print 'packed init rounds:', real_init_rounds
print('packed init rounds:', real_init_rounds)
@for_range(real_init_rounds)
def f(i):
if random_init:
@@ -1467,7 +1468,7 @@ class PackedIndexStructure(object):
self.l[i] = [0] * self.elements_per_block
time()
print_ln('packed ORAM init %s/%s', i, real_init_rounds)
print 'index initialized, size', size
print('index initialized, size', size)
def translate_index(self, index):
""" Bit slicing *index* according parameters. Output is tuple
(storage address, index with storage cell, index within
@@ -1501,16 +1502,16 @@ class PackedIndexStructure(object):
self.block = block
self.index_vector = \
demux(bit_decompose(self.b, self.pack.log_elements_per_block))
self.vector = map(operator.mul, self.index_vector, block)
self.vector = list(map(operator.mul, self.index_vector, block))
self.element = get_block(sum(self.vector), self.c, \
self.pack.entry_size, \
self.pack.entries_per_element)
return tuple(self.element.get_slice())
def write(self, value):
self.element.set_slice(value)
anti_vector = map(operator.sub, self.block, self.vector)
anti_vector = list(map(operator.sub, self.block, self.vector))
updated_vector = [self.element.value * i for i in self.index_vector]
updated_block = map(operator.add, anti_vector, updated_vector)
updated_block = list(map(operator.add, anti_vector, updated_vector))
return updated_block
class MultiSlicer(object):
def __init__(self, pack, index):
@@ -1685,7 +1686,7 @@ def test_oram(oram_type, N, value_type=sint, iterations=100):
value_type = value_type.get_type(32)
index_type = value_type.get_type(log2(N))
start_grind()
print 'initialized'
print('initialized')
print_ln('initialized')
stop_timer()
# synchronize
@@ -1718,7 +1719,7 @@ def test_oram(oram_type, N, value_type=sint, iterations=100):
def test_oram_access(oram_type, N, value_type=sint, index_size=None, iterations=100):
oram = oram_type(N, value_type=value_type, entry_size=32, \
init_rounds=0)
print 'initialized'
print('initialized')
print_reg(cint(0), 'init')
stop_timer()
# synchronize
@@ -1731,11 +1732,11 @@ def test_oram_access(oram_type, N, value_type=sint, index_size=None, iterations=
def f(i):
oram.access(value_type(i % N), value_type(0), value_type(True))
oram.access(value_type(i % N), value_type(i % N), value_type(True))
print 'first write'
print('first write')
time()
x = oram.access(value_type(i % N), value_type(0), value_type(False))
x[0][0].reveal().print_reg('writ')
print 'first read'
print('first read')
# @for_range(iterations)
# def f(i):
# x = oram.access(value_type(i % N), value_type(0), value_type(False), \
@@ -1747,7 +1748,7 @@ def test_oram_access(oram_type, N, value_type=sint, index_size=None, iterations=
def test_batch_init(oram_type, N):
value_type = sint
oram = oram_type(N, value_type)
print 'initialized'
print('initialized')
print_reg(cint(0), 'init')
oram.batch_init([value_type(i) for i in range(N)])
print_reg(cint(0), 'done')

View File

@@ -1,9 +1,10 @@
if '_Array' not in dir():
from oram import *
import permutation
from Compiler.oram import *
from Compiler import permutation
_Array = Array
import oram
from Compiler import oram
from functools import reduce
#import pdb
@@ -140,7 +141,7 @@ class PathORAM(TreeORAM):
bucket_size=2, init_rounds=-1):
#if size <= k:
# raise CompilerError('ORAM size too small')
print 'create oram of size', size
print('create oram of size', size)
self.bucket_oram = bucket_oram
self.bucket_size = bucket_size
self.D = log2(size)
@@ -240,7 +241,7 @@ class PathORAM(TreeORAM):
self.state.write(self.value_type(leaf))
print 'eviction leaf =', leaf
print('eviction leaf =', leaf)
# load the path
for i, ram_indices in enumerate(self.bucket_indices_on_path_to(leaf)):
@@ -325,7 +326,7 @@ class PathORAM(TreeORAM):
# at most one 1 in found
empty = 1 - sum(found)
prod_entries = map(operator.mul, found, entries)
prod_entries = list(map(operator.mul, found, entries))
read_value = sum((entry.x.skip(skip) for entry in prod_entries), \
empty * empty_entry.x.skip(skip))
for i,(j, entry, prod_entry) in enumerate(zip(ram_indices, entries, prod_entries)):
@@ -528,7 +529,7 @@ class PathORAM(TreeORAM):
values = (ValueTuple(x) for x in zip(*self.read_value))
not_empty = [1 - x for x in self.read_empty]
read_empty = 1 - sum(not_empty)
read_value = sum(map(operator.mul, not_empty, values), \
read_value = sum(list(map(operator.mul, not_empty, values)), \
ValueTuple(0 for i in range(self.value_length)))
self.check(u)
Program.prog.curr_tape.\
@@ -545,7 +546,7 @@ class PathORAM(TreeORAM):
yield bucket
def bucket_indices_on_path_to(self, leaf):
leaf = regint(leaf)
yield range(self.bucket_size)
yield list(range(self.bucket_size))
index = 0
for i in range(self.D):
index = 2*index + 1 + regint(cint(leaf) & 1)
@@ -742,7 +743,7 @@ class PathORAM(TreeORAM):
try:
self.stash.add(e)
except Exception:
print self
print(self)
raise
if evict:
self.evict()

View File

@@ -69,7 +69,7 @@ def odd_even_merge(a, comp):
odd_even_merge(even, comp)
odd_even_merge(odd, comp)
a[0] = even[0]
for i in range(1, len(a) / 2):
for i in range(1, len(a) // 2):
a[2*i-1], a[2*i] = cond_swap(odd[i-1], even[i], comp)
a[-1] = odd[-1]
@@ -77,8 +77,8 @@ def odd_even_merge_sort(a, comp=bitwise_comparator):
if len(a) == 1:
return
elif len(a) % 2 == 0:
lower = a[:len(a)/2]
upper = a[len(a)/2:]
lower = a[:len(a)//2]
upper = a[len(a)//2:]
odd_even_merge_sort(lower, comp)
odd_even_merge_sort(upper, comp)
a[:] = lower + upper
@@ -137,7 +137,7 @@ def random_perm(n):
if not Program.prog.options.insecure:
raise CompilerError('no secure implementation of Waksman permution, '
'use --insecure to activate')
a = range(n)
a = list(range(n))
for i in range(n-1, 0, -1):
j = randint(0, i)
t = a[i]
@@ -155,10 +155,10 @@ def configure_waksman(perm):
n = len(perm)
if n == 2:
return [(perm[0], perm[0])]
I = [None] * (n/2)
O = [None] * (n/2)
p0 = [None] * (n/2)
p1 = [None] * (n/2)
I = [None] * (n//2)
O = [None] * (n//2)
p0 = [None] * (n//2)
p1 = [None] * (n//2)
inv_perm = [0] * n
for i, p in enumerate(perm):
@@ -170,7 +170,7 @@ def configure_waksman(perm):
except ValueError:
break
#print 'j =', j
O[j/2] = 0
O[j//2] = 0
via = 0
j0 = j
while True:
@@ -178,10 +178,10 @@ def configure_waksman(perm):
i = inv_perm[j]
#print ' p0[%d] = %d' % (inv_perm[j]/2, j/2)
p0[i/2] = j/2
p0[i//2] = j//2
I[i/2] = i % 2
O[j/2] = j % 2
I[i//2] = i % 2
O[j//2] = j % 2
#print ' O[%d] = %d' % (j/2, j % 2)
if i % 2 == 1:
i -= 1
@@ -198,7 +198,7 @@ def configure_waksman(perm):
j += 1
#j, via = set_swapper(O, i, via, perm)
#print ' p1[%d] = %d' % (i/2, perm[i]/2)
p1[i/2] = perm[i]/2
p1[i//2] = perm[i]//2
#print ' i = %d, j = %d' %(i,j)
if j == j0:
@@ -206,8 +206,8 @@ def configure_waksman(perm):
if None not in p0 and None not in p1:
break
assert sorted(p0) == range(n/2)
assert sorted(p1) == range(n/2)
assert sorted(p0) == list(range(n//2))
assert sorted(p1) == list(range(n//2))
p0_config = configure_waksman(p0)
p1_config = configure_waksman(p1)
return [I + O] + [a+b for a,b in zip(p0_config, p1_config)]
@@ -219,22 +219,22 @@ def waksman(a, config, depth=0, start=0, reverse=False):
a[0], a[1] = cond_swap_bit(a[0], a[1], config[depth][start])
return
a0 = [0] * (n/2)
a1 = [0] * (n/2)
for i in range(n/2):
a0 = [0] * (n//2)
a1 = [0] * (n//2)
for i in range(n//2):
if reverse:
a0[i], a1[i] = cond_swap_bit(a[2*i], a[2*i+1], config[depth][i + n/2 + start])
a0[i], a1[i] = cond_swap_bit(a[2*i], a[2*i+1], config[depth][i + n//2 + start])
else:
a0[i], a1[i] = cond_swap_bit(a[2*i], a[2*i+1], config[depth][i + start])
waksman(a0, config, depth+1, start, reverse)
waksman(a1, config, depth+1, start + n/2, reverse)
waksman(a1, config, depth+1, start + n//2, reverse)
for i in range(n/2):
for i in range(n//2):
if reverse:
a[2*i], a[2*i+1] = cond_swap_bit(a0[i], a1[i], config[depth][i + start])
else:
a[2*i], a[2*i+1] = cond_swap_bit(a0[i], a1[i], config[depth][i + n/2 + start])
a[2*i], a[2*i+1] = cond_swap_bit(a0[i], a1[i], config[depth][i + n//2 + start])
WAKSMAN_FUNCTIONS = {}
@@ -263,11 +263,11 @@ def iter_waksman(a, config, reverse=False):
outwards = 1 - inwards
sizeval = size
#for k in range(n/2):
@for_range_parallel(200, n/2)
#for k in range(n//2):
@for_range_parallel(200, n//2)
def f(k):
j = cint(k) % sizeval
i = (cint(k) - j)/sizeval
i = (cint(k) - j)//sizeval
base = 2*i*sizeval
in1, in2 = (base+j+j*inwards), (base+j+j*inwards+1*inwards+sizeval*outwards)
@@ -297,7 +297,7 @@ def iter_waksman(a, config, reverse=False):
# going into middle of network
@for_range(logn)
def f(i):
size.write(n/(2*nblocks))
size.write(n//(2*nblocks))
conf_address = MemValue(config.address + depth.read()*n)
do_round(size, conf_address, a.address, a2.address, 1)
@@ -307,20 +307,20 @@ def iter_waksman(a, config, reverse=False):
nblocks.write(nblocks*2)
depth.write(depth+1)
nblocks.write(nblocks/4)
nblocks.write(nblocks//4)
depth.write(depth-2)
# and back out
@for_range(logn-1)
def f(i):
size.write(n/(2*nblocks))
size.write(n//(2*nblocks))
conf_address = MemValue(config.address + depth.read()*n)
do_round(size, conf_address, a.address, a2.address, 0)
for i in range(n):
a[i] = a2[i]
nblocks.write(nblocks/2)
nblocks.write(nblocks//2)
depth.write(depth-1)
## going into middle of network
@@ -375,7 +375,7 @@ def config_shuffle(n, value_type):
if n & (n-1) != 0:
# pad permutation to power of 2
m = 2**int(math.ceil(math.log(n, 2)))
perm += range(n, m)
perm += list(range(n, m))
config_bits = configure_waksman(perm)
# 2-D array
config = Array(len(config_bits) * len(perm), value_type.reg_type)

View File

@@ -3,15 +3,17 @@ from Compiler.exceptions import *
from Compiler.instructions_base import RegType
import Compiler.instructions
import Compiler.instructions_base
import compilerLib
import allocator as al
from . import compilerLib
from . import allocator as al
from . import util
import random
import time
import sys, os, errno
import inspect
from collections import defaultdict
from collections import defaultdict, deque
import itertools
import math
from functools import reduce
data_types = dict(
@@ -50,11 +52,11 @@ class Program(object):
self.bit_length = int(options.binary) or int(options.field)
if not self.bit_length:
self.bit_length = BIT_LENGTHS[param]
print 'Default bit length:', self.bit_length
print('Default bit length:', self.bit_length)
self.security = 40
print 'Default security parameter:', self.security
print('Default security parameter:', self.security)
self.galois_length = int(options.galois)
print 'Galois length:', self.galois_length
print('Galois length:', self.galois_length)
self.schedule = [('start', [])]
self.tape_counter = 0
self.tapes = []
@@ -118,7 +120,7 @@ class Program(object):
running[tape] -= 1
else:
raise CompilerError('Invalid schedule action')
res = max(res, sum(running.itervalues()))
res = max(res, sum(running.values()))
return res
def init_names(self, args, assemblymode):
@@ -129,7 +131,7 @@ class Program(object):
else:
# assume source is in main SPDZ directory
self.programs_dir = sys.path[0] + '/Programs'
print 'Compiling program in', self.programs_dir
print('Compiling program in', self.programs_dir)
# create extra directories if needed
for dirname in ['Public-Input', 'Bytecode', 'Schedules']:
@@ -225,7 +227,7 @@ class Program(object):
def read_memory(self, filename):
""" Read the clear and shared memory from a file """
f = open(filename)
n = int(f.next())
n = int(next(f))
self.mem_c = [0]*n
self.mem_s = [0]*n
mem = self.mem_c
@@ -253,8 +255,8 @@ class Program(object):
""" Reset register and memory values. """
for tape in self.tapes:
tape.reset_registers()
self.mem_c = range(USER_MEM + TMP_MEM)
self.mem_s = range(USER_MEM + TMP_MEM)
self.mem_c = list(range(USER_MEM + TMP_MEM))
self.mem_s = list(range(USER_MEM + TMP_MEM))
def write_bytes(self, outfile=None):
""" Write all non-empty threads and schedule to files. """
@@ -265,7 +267,7 @@ class Program(object):
sch_filename = self.programs_dir + '/Schedules/%s.sch' % self.name
sch_file = open(sch_filename, 'w')
print 'Writing to', sch_filename
print('Writing to', sch_filename)
sch_file.write(str(self.max_par_tapes()) + '\n')
sch_file.write(str(len(nonempty_tapes)) + '\n')
sch_file.write(' '.join(tape.name for tape in nonempty_tapes) + '\n')
@@ -276,7 +278,7 @@ class Program(object):
for sch in self.schedule:
# schedule may still contain empty tapes: ignore these
tapes = filter(lambda x: not x[0].is_empty(), sch[1])
tapes = [x for x in sch[1] if not x[0].is_empty()]
# no empty line
if not tapes:
continue
@@ -358,7 +360,7 @@ class Program(object):
def malloc(self, size, mem_type, reg_type=None):
""" Allocate memory from the top """
if not isinstance(size, (int, long)):
if not isinstance(size, int):
raise CompilerError('size must be known at compile time')
if size == 0:
return
@@ -374,7 +376,7 @@ class Program(object):
addr = self.allocated_mem[mem_type]
self.allocated_mem[mem_type] += size
if len(str(addr)) != len(str(addr + size)):
print "Memory of type '%s' now of size %d" % (mem_type, addr + size)
print("Memory of type '%s' now of size %d" % (mem_type, addr + size))
self.allocated_mem_blocks[addr,mem_type] = size
return addr
@@ -387,11 +389,11 @@ class Program(object):
self.free_mem_blocks[size,mem_type].add(addr)
def finalize_memory(self):
import library
from . import library
self.curr_tape.start_new_basicblock(None, 'memory-usage')
# reset register counter to 0
self.curr_tape.init_registers()
for mem_type,size in self.allocated_mem.items():
for mem_type,size in list(self.allocated_mem.items()):
if size:
#print "Memory of type '%s' of size %d" % (mem_type, size)
if mem_type in self.types:
@@ -404,11 +406,11 @@ class Program(object):
def set_bit_length(self, bit_length):
self.bit_length = bit_length
print 'Changed bit length for comparisons etc. to', bit_length
print('Changed bit length for comparisons etc. to', bit_length)
def set_security(self, security):
self.security = security
print 'Changed statistical security for comparison etc. to', security
print('Changed statistical security for comparison etc. to', security)
def optimize_for_gc(self):
pass
@@ -500,9 +502,9 @@ class Tape:
#print 'Basic block %d jumps to %d (%d)' % (next_block_index, jump_index, offset)
def purge(self):
relevant = lambda inst: inst.add_usage.__func__ is not \
Compiler.instructions_base.Instruction.add_usage.__func__
self.usage_instructions = filter(relevant, self.instructions)
relevant = lambda inst: inst.add_usage is not \
Compiler.instructions_base.Instruction.add_usage
self.usage_instructions = list(filter(relevant, self.instructions))
del self.instructions
del self.defined_registers
self.purged = True
@@ -568,8 +570,8 @@ class Tape:
def unpurged(function):
def wrapper(self, *args, **kwargs):
if self.purged:
print '%s called on purged block %s, ignoring' % \
(function.__name__, self.name)
print('%s called on purged block %s, ignoring' % \
(function.__name__, self.name))
return
return function(self, *args, **kwargs)
return wrapper
@@ -577,13 +579,13 @@ class Tape:
@unpurged
def optimize(self, options):
if len(self.basicblocks) == 0:
print 'Tape %s is empty' % self.name
print('Tape %s is empty' % self.name)
return
if self.if_states:
raise CompilerError('Unclosed if/else blocks')
print 'Processing tape', self.name, 'with %d blocks' % len(self.basicblocks)
print('Processing tape', self.name, 'with %d blocks' % len(self.basicblocks))
for block in self.basicblocks:
al.determine_scope(block, options)
@@ -593,38 +595,38 @@ class Tape:
if (options.merge_opens and self.merge_opens) or options.dead_code_elimination:
for i,block in enumerate(self.basicblocks):
if len(block.instructions) > 0:
print 'Processing basic block %s, %d/%d, %d instructions' % \
print('Processing basic block %s, %d/%d, %d instructions' % \
(block.name, i, len(self.basicblocks), \
len(block.instructions))
len(block.instructions)))
# the next call is necessary for allocation later even without merging
merger = al.Merger(block, options, \
tuple(self.program.to_merge))
if options.dead_code_elimination:
if len(block.instructions) > 10000:
print 'Eliminate dead code...'
print('Eliminate dead code...')
merger.eliminate_dead_code()
if options.merge_opens and self.merge_opens:
if len(block.instructions) == 0:
block.used_from_scope = set()
block.defined_registers = set()
block.used_from_scope = util.set_by_id()
block.defined_registers = util.set_by_id()
continue
if len(block.instructions) > 10000:
print 'Merging instructions...'
print('Merging instructions...')
numrounds = merger.longest_paths_merge()
block.n_rounds = numrounds
block.n_to_merge = len(merger.open_nodes)
if numrounds > 0:
print 'Program requires %d rounds of communication' % numrounds
print('Program requires %d rounds of communication' % numrounds)
if merger.counter:
print 'Block requires', \
print('Block requires', \
', '.join('%d %s' % (y, x.__name__) \
for x, y in merger.counter.items())
for x, y in list(merger.counter.items())))
# free memory
merger = None
if options.dead_code_elimination:
block.instructions = filter(lambda x: x is not None, block.instructions)
block.instructions = [x for x in block.instructions if x is not None]
if not (options.merge_opens and self.merge_opens):
print 'Not merging instructions in tape %s' % self.name
print('Not merging instructions in tape %s' % self.name)
# add jumps
offset = 0
@@ -640,39 +642,44 @@ class Tape:
block.adjust_return()
# now remove any empty blocks (must be done after setting jumps)
self.basicblocks = filter(lambda x: len(x.instructions) != 0, self.basicblocks)
self.basicblocks = [x for x in self.basicblocks if len(x.instructions) != 0]
# allocate registers
reg_counts = self.count_regs()
if not options.noreallocate:
if self.program.verbose:
print 'Tape register usage:', dict(reg_counts)
print 'modp: %d clear, %d secret' % (reg_counts[RegType.ClearModp], reg_counts[RegType.SecretModp])
print 'GF2N: %d clear, %d secret' % (reg_counts[RegType.ClearGF2N], reg_counts[RegType.SecretGF2N])
print 'Re-allocating...'
print('Tape register usage:', dict(reg_counts))
print('modp: %d clear, %d secret' % (reg_counts[RegType.ClearModp], reg_counts[RegType.SecretModp]))
print('GF2N: %d clear, %d secret' % (reg_counts[RegType.ClearGF2N], reg_counts[RegType.SecretGF2N]))
print('Re-allocating...')
allocator = al.StraightlineAllocator(REG_MAX)
def alloc_loop(block):
def alloc(block):
for reg in sorted(block.used_from_scope,
key=lambda x: (x.reg_type, x.i)):
allocator.alloc_reg(reg, block.alloc_pool)
for child in block.children:
if child.instructions:
alloc_loop(child)
def alloc_loop(block):
left = deque([block])
while left:
block = left.popleft()
alloc(block)
for child in block.children:
if child.instructions:
left.append(child)
for i,block in enumerate(reversed(self.basicblocks)):
if len(block.instructions) > 10000:
print 'Allocating %s, %d/%d' % \
(block.name, i, len(self.basicblocks))
print('Allocating %s, %d/%d' % \
(block.name, i, len(self.basicblocks)))
if block.exit_condition is not None:
jump = block.exit_condition.get_relative_jump()
if isinstance(jump, (int,long)) and jump < 0 and \
if isinstance(jump, int) and jump < 0 and \
block.exit_block.scope is not None:
alloc_loop(block.exit_block.scope)
allocator.process(block.instructions, block.alloc_pool)
# offline data requirements
print 'Compile offline data requirements...'
print('Compile offline data requirements...')
self.req_num = self.req_tree.aggregate()
print 'Tape requires', self.req_num
print('Tape requires', self.req_num)
for req,num in sorted(self.req_num.items()):
if num == float('inf') or num >= 2 ** 32:
num = -1
@@ -706,8 +713,8 @@ class Tape:
Compiler.instructions.reqbl(bl,
add_to_prog=False))
if self.program.verbose:
print 'Tape requires prime bit length', self.req_bit_length['p']
print 'Tape requires galois bit length', self.req_bit_length['2']
print('Tape requires prime bit length', self.req_bit_length['p'])
print('Tape requires galois bit length', self.req_bit_length['2'])
@unpurged
def _get_instructions(self):
@@ -722,12 +729,12 @@ class Tape:
@unpurged
def get_bytes(self):
""" Get the byte encoding of the program as an actual string of bytes. """
return "".join(str(i.get_bytes()) for i in self._get_instructions() if i is not None)
return b"".join(i.get_bytes() for i in self._get_instructions() if i is not None)
@unpurged
def write_encoding(self, filename):
""" Write the readable encoding to a file. """
print 'Writing to', filename
print('Writing to', filename)
f = open(filename, 'w')
for line in self.get_encoding():
f.write(str(line) + '\n')
@@ -736,7 +743,7 @@ class Tape:
@unpurged
def write_str(self, filename):
""" Write the sequence of instructions to a file. """
print 'Writing to', filename
print('Writing to', filename)
f = open(filename, 'w')
n = 0
for block in self.basicblocks:
@@ -756,8 +763,8 @@ class Tape:
filename += '.bc'
if not 'Bytecode' in filename:
filename = self.program.programs_dir + '/Bytecode/' + filename
print 'Writing to', filename
f = open(filename, 'w')
print('Writing to', filename)
f = open(filename, 'wb')
f.write(self.get_bytes())
f.close()
@@ -785,9 +792,9 @@ class Tape:
super(Tape.ReqNum, self).__init__(lambda: 0, init)
def __add__(self, other):
res = Tape.ReqNum()
for i,count in self.items():
for i,count in list(self.items()):
res[i] += count
for i,count in other.items():
for i,count in list(other.items()):
res[i] += count
return res
def __mul__(self, other):
@@ -798,7 +805,7 @@ class Tape:
__rmul__ = __mul__
def set_all(self, value):
if value == float('inf') and self['all', 'inv'] > 0:
print 'Going to unknown from %s' % self
print('Going to unknown from %s' % self)
res = Tape.ReqNum()
for i in self:
res[i] = value
@@ -811,14 +818,14 @@ class Tape:
res[i] = max(self[i], other[i])
return res
def cost(self):
return sum(num * COST[req[0]][req[1]] for req,num in self.items() \
return sum(num * COST[req[0]][req[1]] for req,num in list(self.items()) \
if req[1] != 'input')
def __str__(self):
return ", ".join('%s inputs in %s from player %d' \
% (num, req[0], req[2]) \
if req[1] == 'input' \
else '%s %ss in %s' % (num, req[1], req[0]) \
for req,num in self.items())
for req,num in list(self.items()))
def __repr__(self):
return repr(dict(self))
@@ -853,8 +860,8 @@ class Tape:
n_rounds = res['all', 'round']
n_invs = res['all', 'inv']
if (n_invs / n_rounds) * 1000 < n_reps:
print self.nodes[0].blocks[0].name, 'blowing up rounds: ', \
'(%d / %d) ** 3 < %d' % (n_rounds, n_reps, n_invs)
print(self.nodes[0].blocks[0].name, 'blowing up rounds: ', \
'(%d / %d) ** 3 < %d' % (n_rounds, n_reps, n_invs))
except:
pass
return res
@@ -892,7 +899,7 @@ class Tape:
The 'value' property is for emulation.
"""
__slots__ = ["reg_type", "program", "i", "value", "_is_active", \
__slots__ = ["reg_type", "program", "i", "_is_active", \
"size", "vector", "vectorbase", "caller", \
"can_eliminate"]
@@ -925,7 +932,7 @@ class Tape:
else:
self.caller = None
if self.i % 1000000 == 0 and self.i > 0:
print "Initialized %d registers at" % self.i, time.asctime()
print("Initialized %d registers at" % self.i, time.asctime())
def set_size(self, size):
if self.size == size:

View File

@@ -2,11 +2,12 @@ from Compiler.program import Tape
from Compiler.exceptions import *
from Compiler.instructions import *
from Compiler.instructions_base import *
from floatingpoint import two_power
import comparison, floatingpoint
from .floatingpoint import two_power
from . import comparison, floatingpoint
import math
import util
from . import util
import operator
from functools import reduce
class ClientMessageType:
@@ -127,15 +128,15 @@ class _number(object):
return self * self
def __add__(self, other):
if other is 0 or other is 0L:
if other is 0:
return self
else:
return self.add(other)
def __mul__(self, other):
if other is 0 or other is 0L:
if other is 0:
return 0
elif other is 1 or other is 1L:
elif other is 1:
return self
else:
return self.mul(other)
@@ -301,7 +302,7 @@ class _register(Tape.Register, _number, _structure):
if isinstance(val, (tuple, list)):
size = len(val)
super(_register, self).__init__(reg_type, program.curr_tape, size=size)
if isinstance(val, (int, long)):
if isinstance(val, int):
self.load_int(val)
elif isinstance(val, (tuple, list)):
for i, x in enumerate(val):
@@ -374,7 +375,7 @@ class _clear(_register):
res = self.prep_res(other)
if isinstance(other, cls):
c_inst(res, self, other)
elif isinstance(other, (int, long)):
elif isinstance(other, int):
if self.in_immediate_range(other):
ci_inst(res, self, other)
else:
@@ -392,7 +393,7 @@ class _clear(_register):
def coerce_op(self, other, inst, reverse=False):
cls = self.__class__
res = cls()
if isinstance(other, (int, long)):
if isinstance(other, int):
other = cls(other)
elif not isinstance(other, cls):
return NotImplemented
@@ -414,14 +415,14 @@ class _clear(_register):
def __rsub__(self, other):
return self.clear_op(other, subc, subcfi, True)
def __div__(self, other):
def __truediv__(self, other):
return self.clear_op(other, divc, divci)
def __rdiv__(self, other):
def __rtruediv__(self, other):
return self.coerce_op(other, divc, True)
def __eq__(self, other):
if isinstance(other, (_clear,int,long)):
if isinstance(other, (_clear,int)):
return regint(self) == other
else:
return NotImplemented
@@ -493,12 +494,12 @@ class cint(_clear, _int):
ldi(self, val)
else:
max = 2**31 - 1
sign = abs(val) / val
sign = abs(val) // val
val = abs(val)
chunks = []
while val:
mod = val % max
val = (val - mod) / max
val = (val - mod) // max
chunks.append(mod)
sum = cint(sign * chunks.pop())
for i,chunk in enumerate(reversed(chunks)):
@@ -520,13 +521,13 @@ class cint(_clear, _int):
return self.coerce_op(other, modc, True)
def __lt__(self, other):
if isinstance(other, (type(self),int,long)):
if isinstance(other, (type(self),int)):
return regint(self) < other
else:
return NotImplemented
def __gt__(self, other):
if isinstance(other, (type(self),int,long)):
if isinstance(other, (type(self),int)):
return regint(self) > other
else:
return NotImplemented
@@ -537,6 +538,23 @@ class cint(_clear, _int):
def __ge__(self, other):
return 1 - (self < other)
@vectorize
def __eq__(self, other):
if not isinstance(other, (_clear, int)):
return NotImplemented
res = 1
remaining = program.bit_length
while remaining > 0:
if isinstance(other, cint):
o = other.to_regint(min(remaining, 64))
else:
o = other % 2 ** 64
res *= (self.to_regint(min(remaining, 64)) == o)
self >>= 64
other >>= 64
remaining -= 64
return res
def __lshift__(self, other):
return self.clear_op(other, shlc, shlci)
@@ -683,7 +701,7 @@ class cgf2n(_clear, _gf2n):
def bit_decompose(self, bit_length=None, step=None):
bit_length = bit_length or program.galois_length
step = step or 1
res = [type(self)() for _ in range(bit_length / step)]
res = [type(self)() for _ in range(bit_length // step)]
gbitdec(self, step, *res)
return res
@@ -817,12 +835,15 @@ class regint(_register, _int):
def __neg__(self):
return 0 - self
def __div__(self, other):
def __floordiv__(self, other):
return self.int_op(other, divint)
def __rdiv__(self, other):
def __rfloordiv__(self, other):
return self.int_op(other, divint, True)
__truediv__ = __floordiv__
__rtruediv__ = __rfloordiv__
def __mod__(self, other):
return self - (self / other) * other
@@ -851,13 +872,13 @@ class regint(_register, _int):
return 1 - (self < other)
def __lshift__(self, other):
if isinstance(other, (int, long)):
if isinstance(other, int):
return self * 2**other
else:
return regint(cint(self) << other)
def __rshift__(self, other):
if isinstance(other, (int, long)):
if isinstance(other, int):
return self / 2**other
else:
return regint(cint(self) >> other)
@@ -911,6 +932,24 @@ class regint(_register, _int):
def print_if(self, string):
cint(self).print_if(string)
class localint(object):
""" Local integer that must prevented from leaking into the secure
computation. Uses regint internally. """
def __init__(self, value=None):
self._v = regint(value)
self.size = 1
def output(self):
self._v.print_reg_plain()
__lt__ = lambda self, other: localint(self._v < other)
__le__ = lambda self, other: localint(self._v <= other)
__gt__ = lambda self, other: localint(self._v > other)
__ge__ = lambda self, other: localint(self._v >= other)
__eq__ = lambda self, other: localint(self._v == other)
__ne__ = lambda self, other: localint(self._v != other)
class _secret(_register):
__slots__ = []
@@ -996,11 +1035,11 @@ class _secret(_register):
def matrix_mul(cls, A, B, n, res_params=None):
assert len(A) % n == 0
assert len(B) % n == 0
size = len(A) * len(B) / n**2
size = len(A) * len(B) // n**2
res = cls(size=size)
n_rows = len(A) / n
n_cols = len(B) / n
dotprods(*sum(([res[j], [A[j / n_cols * n + k] for k in range(n)],
n_rows = len(A) // n
n_cols = len(B) // n
dotprods(*sum(([res[j], [A[j // n_cols * n + k] for k in range(n)],
[B[k * n_cols + j % n_cols] for k in range(n)]]
for j in range(size)), []))
return res
@@ -1054,7 +1093,7 @@ class _secret(_register):
m_inst(res, other, self)
else:
m_inst(res, self, other)
elif isinstance(other, (int, long)):
elif isinstance(other, int):
if self.clear_type.in_immediate_range(other):
si_inst(res, self, other)
else:
@@ -1086,11 +1125,11 @@ class _secret(_register):
return self.secret_op(other, subs, submr, subsfi, True)
@vectorize
def __div__(self, other):
def __truediv__(self, other):
return self * (self.clear_type(1) / other)
@vectorize
def __rdiv__(self, other):
def __rtruediv__(self, other):
a,b = self.get_random_inverse()
return other * a / (a * self).reveal()
@@ -1253,7 +1292,7 @@ class sint(_secret, _int):
@vectorize
def __mod__(self, modulus):
if isinstance(modulus, (int, long)):
if isinstance(modulus, int):
l = math.log(modulus, 2)
if 2**int(round(l)) == modulus:
return self.mod2m(int(l))
@@ -1405,7 +1444,7 @@ class sgf2n(_secret, _gf2n):
return self ^ cgf2n(2**program.galois_length - 1)
def __xor__(self, other):
if other is 0 or other is 0L:
if other is 0:
return self
else:
return super(sgf2n, self).add(other)
@@ -1414,7 +1453,7 @@ class sgf2n(_secret, _gf2n):
@vectorize
def __and__(self, other):
if isinstance(other, (int, long)):
if isinstance(other, int):
other_bits = [(other >> i) & 1 \
for i in range(program.galois_length)]
else:
@@ -1515,7 +1554,7 @@ class _bitint(object):
else:
pre_op = floatingpoint.PreOpL
if d:
carries = zip(*pre_op(carry, [(0, carry_in)] + d))[1]
carries = list(zip(*pre_op(carry, [(0, carry_in)] + d)))[1]
else:
carries = []
res = lower + cls.sum_from_carries(a, b, carries)
@@ -1539,7 +1578,7 @@ class _bitint(object):
for k in range(m, -1, -1):
if sum(range(m, k - 1, -1)) + 1 >= n:
break
blocks = range(m, k, -1)
blocks = list(range(m, k, -1))
blocks.append(n - sum(blocks))
blocks.reverse()
#print 'blocks:', blocks
@@ -1597,9 +1636,9 @@ class _bitint(object):
@staticmethod
def get_highest_different_bits(a, b, index):
diff = [ai + bi for (ai,bi) in reversed(zip(a,b))]
diff = [ai + bi for (ai,bi) in reversed(list(zip(a,b)))]
preor = floatingpoint.PreOR(diff, raw=True)
highest_diff = [x - y for (x,y) in reversed(zip(preor, [0] + preor))]
highest_diff = [x - y for (x,y) in reversed(list(zip(preor, [0] + preor)))]
raw = sum(map(operator.mul, highest_diff, (a,b)[index]))
return raw.bit_decompose()[0]
@@ -1622,7 +1661,7 @@ class _bitint(object):
if type(other) == self.bin_type:
raise CompilerError('Unclear multiplication')
self_bits = self.bit_decompose()
if isinstance(other, (int, long)):
if isinstance(other, int):
other_bits = util.bit_decompose(other, self.n_bits)
bit_matrix = [[x * y for y in self_bits] for x in other_bits]
else:
@@ -1644,8 +1683,8 @@ class _bitint(object):
@classmethod
def wallace_tree_from_matrix(cls, bit_matrix, get_carry=True):
columns = [filter(None, (bit_matrix[j][i-j] \
for j in range(min(len(bit_matrix), i + 1)))) \
columns = [[_f for _f in (bit_matrix[j][i-j] \
for j in range(min(len(bit_matrix), i + 1))) if _f] \
for i in range(len(bit_matrix[0]))]
return cls.wallace_tree_from_columns(columns, get_carry)
@@ -1671,7 +1710,7 @@ class _bitint(object):
columns = new_columns[:-1]
for col in columns:
col.extend([0] * (2 - len(col)))
return self.bit_adder(*zip(*columns))
return self.bit_adder(*list(zip(*columns)))
@classmethod
def wallace_tree(cls, rows):
@@ -1685,17 +1724,17 @@ class _bitint(object):
d = [(1 + ai + bi, (1 - ai) * bi) for (ai,bi) in zip(a,b)]
borrow = lambda y,x,*args: \
(x[0] * y[0], 1 - (1 - x[1]) * (1 - x[0] * y[1]))
borrows = (0,) + zip(*floatingpoint.PreOpL(borrow, d))[1]
borrows = (0,) + list(zip(*floatingpoint.PreOpL(borrow, d)))[1]
return self.compose(ai + bi + borrow \
for (ai,bi,borrow) in zip(a,b,borrows))
def __rsub__(self, other):
raise NotImplementedError()
def __div__(self, other):
def __truediv__(self, other):
raise NotImplementedError()
def __rdiv__(self, other):
def __truerdiv__(self, other):
raise NotImplementedError()
def __lshift__(self, other):
@@ -1953,7 +1992,7 @@ class cfix(_number, _structure):
""" Clear fixed point type. """
__slots__ = ['value', 'f', 'k', 'size']
reg_type = 'c'
scalars = (int, long, float, regint)
scalars = (int, float, regint)
@classmethod
def set_precision(cls, f, k = None):
# k is the whole bitlength of fixed point
@@ -1978,7 +2017,7 @@ class cfix(_number, _structure):
if n == 1:
return cfix(cint_inputs)
else:
return map(cfix, cint_inputs)
return list(map(cfix, cint_inputs))
@vectorize
def write_to_socket(self, client_id, message_type=ClientMessageType.NoType):
@@ -1990,7 +2029,7 @@ class cfix(_number, _structure):
""" Send a list of cfix values to socket. Values are sent as bit shifted cints. """
def cfix_to_cint(fix_val):
return cint(fix_val.v)
cint_values = map(cfix_to_cint, values)
cint_values = list(map(cfix_to_cint, values))
writesocketc(client_id, message_type, *cint_values)
@staticmethod
@@ -2152,7 +2191,7 @@ class cfix(_number, _structure):
raise NotImplementedError
@vectorize
def __div__(self, other):
def __truediv__(self, other):
other = parse_type(other)
if isinstance(other, cfix):
return cfix(library.cint_cint_division(self.v, other.v, self.k, self.f))
@@ -2190,7 +2229,7 @@ class _single(_number, _structure):
""" Securely obtain shares of n values input by a client.
Assumes client has already run bit shift to convert fixed point to integer."""
sint_inputs = cls.int_type.receive_from_client(n, client_id, ClientMessageType.TripleShares)
return map(cls, sint_inputs)
return list(map(cls, sint_inputs))
@vectorized_classmethod
def load_mem(cls, address, mem_type=None):
@@ -2333,7 +2372,7 @@ class _fix(_single):
@classmethod
def coerce(cls, other):
if isinstance(other, (_fix, cfix)):
if isinstance(other, (_fix, cls.clear_type)):
return other
else:
return cls.conv(other)
@@ -2402,7 +2441,7 @@ class _fix(_single):
@vectorize
def mul(self, other):
if isinstance(other, (sint, cint, regint, int, long)):
if isinstance(other, (sint, cint, regint, int)):
return self._new(self.v * other, k=self.k, f=self.f)
elif isinstance(other, float):
if int(other) == other:
@@ -2413,13 +2452,11 @@ class _fix(_single):
f = self.f
while v % 2 == 0:
f -= 1
v /= 2
v //= 2
k = len(bin(abs(v))) - 1
other = cfix(cint(v))
other.f = f
other.k = k
other = self.multipliable(v, k, f)
other = self.coerce(other)
if isinstance(other, (_fix, cfix)):
if isinstance(other, (_fix, self.clear_type)):
val = self.v.TruncMul(other.v, self.k + other.k, other.f,
self.kappa,
self.round_nearest)
@@ -2438,7 +2475,7 @@ class _fix(_single):
return type(self)(-self.v)
@vectorize
def __div__(self, other):
def __truediv__(self, other):
other = self.coerce(other)
if isinstance(other, _fix):
return type(self)(library.FPDiv(self.v, other.v, self.k, self.f,
@@ -2450,7 +2487,7 @@ class _fix(_single):
raise TypeError('Incompatible fixed point types in division')
@vectorize
def __rdiv__(self, other):
def __rtruediv__(self, other):
return self.coerce(other) / self
@vectorize
@@ -2497,6 +2534,10 @@ class sfix(_fix):
def unreduced(self, v, other=None, res_params=None, n_summands=1):
return unreduced_sfix(v, self.k * 2, self.f, self.kappa)
@staticmethod
def multipliable(v, k, f):
return cfix(cint.conv(v), k, f)
class unreduced_sfix(_single):
int_type = sint
@@ -2511,7 +2552,7 @@ class unreduced_sfix(_single):
self.kappa = kappa
def __add__(self, other):
if other is 0 or other is 0L:
if other is 0:
return self
assert self.k == other.k
assert self.m == other.m
@@ -2643,7 +2684,7 @@ class _unreduced_squant(object):
self.res_params = res_params or params[0]
def __add__(self, other):
if other is 0 or other is 0L:
if other is 0:
return self
assert self.params == other.params
assert self.res_params == other.res_params
@@ -2807,10 +2848,10 @@ class sfloat(_number, _structure):
v = int(round(abs(v) * 2 ** (-p)))
if v == 2 ** vlen:
p += 1
v /= 2
v //= 2
z = 0
if p < -2 ** (plen - 1):
print 'Warning: %e truncated to zero' % vv
print('Warning: %e truncated to zero' % vv)
v, p, z = 0, 0, 1
if p >= 2 ** (plen - 1):
raise CompilerError('Cannot convert %s to float ' \
@@ -2950,8 +2991,8 @@ class sfloat(_number, _structure):
v = t
u = floatingpoint.BitDec(v, self.vlen + 2 + sfloat.round_nearest,
self.vlen + 2 + sfloat.round_nearest, self.kappa,
range(1 + sfloat.round_nearest,
self.vlen + 2 + sfloat.round_nearest))
list(range(1 + sfloat.round_nearest,
self.vlen + 2 + sfloat.round_nearest)))
# using u[0] doesn't seem necessary
h = floatingpoint.PreOR(u[:sfloat.round_nearest:-1], self.kappa)
p0 = self.vlen + 1 - sum(h)
@@ -3013,7 +3054,7 @@ class sfloat(_number, _structure):
def __rsub__(self, other):
return -self + other
def __div__(self, other):
def __truediv__(self, other):
other = self.conv(other)
v = floatingpoint.SDiv(self.v, other.v + other.z * (2**self.vlen - 1),
self.vlen, self.kappa, self.round_nearest)
@@ -3029,21 +3070,16 @@ class sfloat(_number, _structure):
sfloat.set_error(other.z)
return sfloat(v, p, z, s)
def __rdiv__(self, other):
def __rtruediv__(self, other):
return self.conv(other) / self
@vectorize
def __neg__(self):
return sfloat(self.v, self.p, self.z, (1 - self.s) * (1 - self.z))
def __abs__(self):
if self.s:
return -self
else:
return self
@vectorize
def __lt__(self, other):
other = self.conv(other)
if isinstance(other, sfloat):
z1 = self.z
z2 = other.z
@@ -3066,8 +3102,15 @@ class sfloat(_number, _structure):
def __ge__(self, other):
return 1 - (self < other)
def __gt__(self, other):
return self.conv(other) < self
def __le__(self, other):
return self.conv(other) >= self
@vectorize
def __eq__(self, other):
other = self.conv(other)
# the sign can be both ways for zeroes
both_zero = self.z * other.z
return floatingpoint.EQZ(self.v - other.v, self.vlen, self.kappa) * \
@@ -3151,24 +3194,25 @@ class Array(object):
program.free(self.address, self.value_type.reg_type)
def get_address(self, index):
key = str(index)
if isinstance(index, int) and self.length is not None:
index += self.length * (index < 0)
if index >= self.length or index < 0:
raise IndexError('index %s, length %s' % \
(str(index), str(self.length)))
if (program.curr_block, index) not in self.address_cache:
if (program.curr_block, key) not in self.address_cache:
n = self.value_type.n_elements()
length = self.length
if n == 1:
# length can be None for single-element arrays
length = 0
self.address_cache[program.curr_block, index] = \
self.address_cache[program.curr_block, key] = \
util.untuplify([self.address + index + i * length \
for i in range(n)])
if self.debug:
library.print_ln_if(index >= self.length, 'OF:' + self.debug)
library.print_ln_if(self.address_cache[program.curr_block, index] >= program.allocated_mem[self.value_type.reg_type], 'AOF:' + self.debug)
return self.address_cache[program.curr_block, index]
library.print_ln_if(self.address_cache[program.curr_block, key] >= program.allocated_mem[self.value_type.reg_type], 'AOF:' + self.debug)
return self.address_cache[program.curr_block, key]
def get_slice(self, index):
if index.stop is None and self.length is None:
@@ -3178,7 +3222,7 @@ class Array(object):
def __getitem__(self, index):
if isinstance(index, slice):
start, stop, step = self.get_slice(index)
res_length = (stop - start - 1) / step + 1
res_length = (stop - start - 1) // step + 1
res = Array(res_length, self.value_type)
@library.for_range(res_length)
def f(i):
@@ -3303,7 +3347,7 @@ class SubMultiArray(object):
def __getitem__(self, index):
if util.is_constant(index) and index >= self.sizes[0]:
raise StopIteration
key = program.curr_block, index
key = program.curr_block, str(index)
if key not in self.sub_cache:
if self.debug:
library.print_ln_if(index >= self.sizes[0], \
@@ -3531,7 +3575,7 @@ class _mem(_number):
__add__ = lambda self,other: self.read() + other
__sub__ = lambda self,other: self.read() - other
__mul__ = lambda self,other: self.read() * other
__div__ = lambda self,other: self.read() / other
__truediv__ = lambda self,other: self.read() / other
__mod__ = lambda self,other: self.read() % other
__pow__ = lambda self,other: self.read() ** other
__neg__ = lambda self,other: -self.read()
@@ -3550,7 +3594,7 @@ class _mem(_number):
__radd__ = lambda self,other: other + self.read()
__rsub__ = lambda self,other: other - self.read()
__rmul__ = lambda self,other: other * self.read()
__rdiv__ = lambda self,other: other / self.read()
__rtruediv__ = lambda self,other: other / self.read()
__rmod__ = lambda self,other: other % self.read()
__rand__ = lambda self,other: other & self.read()
__rxor__ = lambda self,other: other ^ self.read()
@@ -3627,7 +3671,7 @@ class MemValue(_mem):
self.check()
if isinstance(value, MemValue):
self.register = value.read()
elif isinstance(value, (int,long)):
elif isinstance(value, int):
self.register = self.value_type(value)
else:
self.register = value
@@ -3717,7 +3761,7 @@ def getNamedTupleType(*names):
class NamedTuple(object):
class NamedTupleArray(object):
def __init__(self, size, t):
import types
from . import types
self.arrays = [types.Array(size, t) for i in range(len(names))]
def __getitem__(self, index):
return NamedTuple(array[index] for array in self.arrays)
@@ -3749,4 +3793,4 @@ def getNamedTupleType(*names):
return self.__type__(x.reveal() for x in self)
return NamedTuple
import library
from . import library

View File

@@ -1,5 +1,6 @@
import math
import operator
from functools import reduce
def format_trace(trace, prefix=' '):
if trace is None:
@@ -46,7 +47,7 @@ def right_shift(a, b, bits):
return a.right_shift(b, bits)
def bit_decompose(a, bits):
if isinstance(a, (int,long)):
if isinstance(a, int):
return [int((a >> i) & 1) for i in range(bits)]
else:
return a.bit_decompose(bits)
@@ -82,7 +83,7 @@ def if_else(cond, a, b):
else:
return cond.if_else(a, b)
except:
print cond, a, b
print(cond, a, b)
raise
def cond_swap(cond, a, b):
@@ -112,8 +113,8 @@ def tree_reduce(function, sequence):
if n == 1:
return sequence[0]
else:
reduced = [function(sequence[2*i], sequence[2*i+1]) for i in range(n/2)]
return tree_reduce(function, reduced + sequence[n/2*2:])
reduced = [function(sequence[2*i], sequence[2*i+1]) for i in range(n//2)]
return tree_reduce(function, reduced + sequence[n//2*2:])
def or_op(a, b):
return a + b - a * b
@@ -144,7 +145,7 @@ def reveal(x):
return x
def is_constant(x):
return isinstance(x, (int, long, bool))
return isinstance(x, (int, bool))
def is_constant_float(x):
return isinstance(x, float) or is_constant(x)
@@ -180,3 +181,44 @@ def expand(x, size):
return x.expand_to_vector(size)
except AttributeError:
return x
class set_by_id(object):
def __init__(self, init=[]):
self.content = {}
for x in init:
self.add(x)
def __contains__(self, value):
return id(value) in self.content
def __iter__(self):
return iter(self.content.values())
def add(self, value):
self.content[id(value)] = value
class dict_by_id(object):
def __init__(self):
self.content = {}
def __contains__(self, key):
return id(key) in self.content
def __getitem__(self, key):
return self.content[id(key)][1]
def __setitem__(self, key, value):
self.content[id(key)] = (key, value)
def keys(self):
return (x[0] for x in self.content.values())
class defaultdict_by_id(dict_by_id):
def __init__(self, default):
dict_by_id.__init__(self)
self.default = default
def __getitem__(self, key):
if key not in self:
self[key] = self.default()
return dict_by_id.__getitem__(self, key)

66
ECDSA/EcdsaOptions.h Normal file
View File

@@ -0,0 +1,66 @@
/*
* EcdsaOptions.h
*
*/
#ifndef ECDSA_ECDSAOPTIONS_H_
#define ECDSA_ECDSAOPTIONS_H_
#include "Tools/ezOptionParser.h"
class EcdsaOptions
{
public:
bool prep_mul;
bool fewer_rounds;
bool check_open;
bool check_beaver_open;
EcdsaOptions(ez::ezOptionParser& opt, int argc, const char** argv)
{
opt.add(
"", // Default.
0, // Required?
0, // Number of args expected.
0, // Delimiter if expecting multiple args.
"Delay multiplication until signing", // Help description.
"-D", // Flag token.
"--delay-multiplication" // Flag token.
);
opt.add(
"", // Default.
0, // Required?
0, // Number of args expected.
0, // Delimiter if expecting multiple args.
"Fewer rounds, more EC", // Help description.
"-P", // Flag token.
"--parallel-open" // Flag token.
);
opt.add(
"", // Default.
0, // Required?
0, // Number of args expected.
0, // Delimiter if expecting multiple args.
"Skip checking final openings (but not necessarily openings for Beaver; only relevant with active protocols)", // Help description.
"-C", // Flag token.
"--no-open-check" // Flag token.
);
opt.add(
"", // Default.
0, // Required?
0, // Number of args expected.
0, // Delimiter if expecting multiple args.
"Skip checking Beaver openings (only relevant with active protocols)", // Help description.
"-B", // Flag token.
"--no-beaver-open-check" // Flag token.
);
opt.parse(argc, argv);
prep_mul = not opt.isSet("-D");
fewer_rounds = opt.isSet("-P");
check_open = not opt.isSet("-C");
check_beaver_open = not opt.isSet("-B");
opt.resetArgs();
}
};
#endif /* ECDSA_ECDSAOPTIONS_H_ */

View File

@@ -18,7 +18,7 @@ int main()
string prefix = PREP_DIR "ECDSA/";
mkdir_p(prefix.c_str());
ofstream outf;
write_online_setup(outf, prefix, P256Element::Scalar::pr(), 0, false);
write_online_setup_without_init(outf, prefix, P256Element::Scalar::pr(), 0);
generate_mac_keys<Share<P256Element::Scalar>>(key, key2, 2, prefix);
make_mult_triples<Share<P256Element::Scalar>>(key, 2, 1000, false, prefix);
make_inverse<Share<P256Element::Scalar>>(key, 2, 1000, false, prefix);

View File

@@ -77,6 +77,14 @@ P256Element& P256Element::operator +=(const P256Element& other)
return *this;
}
P256Element& P256Element::operator /=(const Scalar& other)
{
auto tmp = other;
tmp.invert();
*this = *this * tmp;
return *this;
}
bool P256Element::operator ==(const P256Element& other) const
{
return point == other.point;

View File

@@ -10,14 +10,10 @@
#include "Math/gfp.h"
#if GFP_MOD_SZ != 4
#error GFP_MOD_SZ must be 4
#endif
class P256Element : public ValueInterface
{
public:
typedef gfp Scalar;
typedef gfp_<2, 4> Scalar;
private:
static CryptoPP::DL_GroupParameters_EC<CryptoPP::ECP> params;
@@ -52,6 +48,7 @@ public:
P256Element operator*(const Scalar& other) const;
P256Element& operator+=(const P256Element& other);
P256Element& operator/=(const Scalar& other);
bool operator==(const P256Element& other) const;
bool operator!=(const P256Element& other) const;

View File

@@ -5,8 +5,7 @@ in `preprocessing.hpp` and `sign.hpp`, respectively.
#### Compilation
- Add `MOD = -DGFP_MOD_SZ=4` to `CONFIG.mine`
- Also consider adding either `CXX = clang++` or `OPTIM = -O2` because GCC 8 or later with `-O3` will produce a segfault when using `mascot-ecdsa-party.x`
- Add either `CXX = clang++` or `OPTIM = -O2` because GCC 8 or later with `-O3` will produce a segfault when using `mascot-ecdsa-party.x`
- For older hardware, also add `ARCH = -march=native`
- Install [Crypto++](https://www.cryptopp.com) (`libcrypto++-dev` on Ubuntu). We used version 5.6.4, which is the default on Ubuntu 18.04.
- Compile the binaries: `make -j8 ecdsa`

View File

@@ -23,6 +23,7 @@
int main(int argc, const char** argv)
{
ez::ezOptionParser opt;
EcdsaOptions opts(opt, argc, argv);
Names N(opt, argc, argv, 2);
int n_tuples = 1000;
if (not opt.lastArgs.empty())
@@ -38,7 +39,7 @@ int main(int argc, const char** argv)
typedef Share<P256Element::Scalar> pShare;
DataPositions usage;
Sub_Data_Files<pShare> prep(N, prefix, usage);
typename pShare::MAC_Check MCp(keyp);
typename pShare::Direct_MC MCp(keyp);
ArithmeticProcessor _({}, 0);
SubProcessor<pShare> proc(_, MCp, prep, P);
@@ -46,7 +47,7 @@ int main(int argc, const char** argv)
proc.DataF.get_two(DATA_INVERSE, sk, __);
vector<EcTuple<Share>> tuples;
preprocessing(tuples, n_tuples, sk, proc);
preprocessing(tuples, n_tuples, sk, proc, opts);
check(tuples, sk, keyp, P);
sign_benchmark(tuples, sk, MCp, P);
sign_benchmark(tuples, sk, MCp, P, opts);
}

View File

@@ -29,15 +29,7 @@ void run(int argc, const char** argv)
{
bigint::init_thread();
ez::ezOptionParser opt;
opt.add(
"", // Default.
0, // Required?
0, // Number of args expected.
0, // Delimiter if expecting multiple args.
"Delay multiplication until signing", // Help description.
"-D", // Flag token.
"--delay-multiplication" // Flag token.
);
EcdsaOptions opts(opt, argc, argv);
Names N(opt, argc, argv, 3);
int n_tuples = 1000;
if (not opt.lastArgs.empty())
@@ -51,10 +43,12 @@ void run(int argc, const char** argv)
P.Broadcast_Receive(bundle, false);
Timer timer;
timer.start();
auto stats = P.comm_stats;
pShare sk = typename T<P256Element::Scalar>::Honest::Protocol(P).get_random();
cout << "Secret key generation took " << timer.elapsed() * 1e3 << " ms" << endl;
(P.comm_stats - stats).print(true);
OnlineOptions::singleton.batch_size = n_tuples;
OnlineOptions::singleton.batch_size = (1 + pShare::Protocol::uses_triples) * n_tuples;
DataPositions usage;
auto& prep = *Preprocessing<pShare>::get_live_prep(0, usage);
typename pShare::MAC_Check MCp;
@@ -63,9 +57,9 @@ void run(int argc, const char** argv)
bool prep_mul = not opt.isSet("-D");
vector<EcTuple<T>> tuples;
preprocessing(tuples, n_tuples, sk, proc, prep_mul);
preprocessing(tuples, n_tuples, sk, proc, opts);
// check(tuples, sk, {}, P);
sign_benchmark(tuples, sk, MCp, P, prep_mul ? 0 : &proc);
sign_benchmark(tuples, sk, MCp, P, opts, prep_mul ? 0 : &proc);
delete &prep;
}

View File

@@ -25,27 +25,72 @@ template<template<class U> class T>
void run(int argc, const char** argv)
{
ez::ezOptionParser opt;
EcdsaOptions opts(opt, argc, argv);
opt.add(
"", // Default.
0, // Required?
0, // Number of args expected.
0, // Delimiter if expecting multiple args.
"Delay multiplication until signing", // Help description.
"-D", // Flag token.
"--delay-multiplication" // Flag token.
"Use SimpleOT instead of OT extension", // Help description.
"-S", // Flag token.
"--simple-ot" // Flag token.
);
opt.add(
"", // Default.
0, // Required?
0, // Number of args expected.
0, // Delimiter if expecting multiple args.
"Don't check correlation in OT extension (only relevant with MASCOT)", // Help description.
"-U", // Flag token.
"--unchecked-correlation" // Flag token.
);
opt.add(
"", // Default.
0, // Required?
0, // Number of args expected.
0, // Delimiter if expecting multiple args.
"Fewer rounds for authentication (only relevant with MASCOT)", // Help description.
"-A", // Flag token.
"--auth-fewer-rounds" // Flag token.
);
opt.add(
"", // Default.
0, // Required?
0, // Number of args expected.
0, // Delimiter if expecting multiple args.
"Use Fiat-Shamir for amplification (only relevant with MASCOT)", // Help description.
"-H", // Flag token.
"--fiat-shamir" // Flag token.
);
opt.add(
"", // Default.
0, // Required?
0, // Number of args expected.
0, // Delimiter if expecting multiple args.
"Skip sacrifice (only relevant with MASCOT)", // Help description.
"-E", // Flag token.
"--embrace-life" // Flag token.
);
opt.add(
"", // Default.
0, // Required?
0, // Number of args expected.
0, // Delimiter if expecting multiple args.
"No MACs (only relevant with MASCOT; implies skipping MAC checks)", // Help description.
"-M", // Flag token.
"--no-macs" // Flag token.
);
Names N(opt, argc, argv, 2);
int n_tuples = 1000;
if (not opt.lastArgs.empty())
n_tuples = atoi(opt.lastArgs[0]->c_str());
PlainPlayer P(N);
P256Element::init();
gfp1::init_field(P256Element::Scalar::pr(), false);
P256Element::Scalar::next::init_field(P256Element::Scalar::pr(), false);
BaseMachine machine;
machine.ot_setups.resize(1);
for (int i = 0; i < 2; i++)
machine.ot_setups[0].push_back({P, true});
machine.ot_setups.push_back({P, true});
P256Element::Scalar keyp;
SeededPRNG G;
@@ -65,16 +110,28 @@ void run(int argc, const char** argv)
P.Broadcast_Receive(bundle, false);
Timer timer;
timer.start();
auto stats = P.comm_stats;
sk_prep.get_two(DATA_INVERSE, sk, __);
cout << "Secret key generation took " << timer.elapsed() * 1e3 << " ms" << endl;
(P.comm_stats - stats).print(true);
OnlineOptions::singleton.batch_size = n_tuples;
OnlineOptions::singleton.batch_size = (1 + pShare::Protocol::uses_triples) * n_tuples;
typename pShare::LivePrep prep(0, usage);
prep.params.correlation_check &= not opt.isSet("-U");
prep.params.fewer_rounds = opt.isSet("-A");
prep.params.fiat_shamir = opt.isSet("-H");
prep.params.check = not opt.isSet("-E");
prep.params.generateMACs = not opt.isSet("-M");
opts.check_beaver_open &= prep.params.generateMACs;
opts.check_open &= prep.params.generateMACs;
SubProcessor<pShare> proc(_, MCp, prep, P);
typename pShare::prep_type::Direct_MC MCpp(keyp);
prep.triple_generator->MC = &MCpp;
bool prep_mul = not opt.isSet("-D");
prep.params.use_extension = not opt.isSet("-S");
vector<EcTuple<T>> tuples;
preprocessing(tuples, n_tuples, sk, proc, prep_mul);
preprocessing(tuples, n_tuples, sk, proc, opts);
//check(tuples, sk, keyp, P);
sign_benchmark(tuples, sk, MCp, P, prep_mul ? 0 : &proc);
sign_benchmark(tuples, sk, MCp, P, opts, prep_mul ? 0 : &proc);
}

View File

@@ -7,6 +7,7 @@
#define ECDSA_PREPROCESSING_HPP_
#include "P256Element.h"
#include "EcdsaOptions.h"
#include "Processor/Data_Files.h"
#include "Protocols/ReplicatedPrep.h"
#include "Protocols/MaliciousShamirShare.h"
@@ -23,40 +24,66 @@ public:
template<template<class U> class T>
void preprocessing(vector<EcTuple<T>>& tuples, int buffer_size,
T<P256Element::Scalar>& sk,
SubProcessor<T<P256Element::Scalar>>& proc, bool prep_mul = true)
SubProcessor<T<P256Element::Scalar>>& proc,
EcdsaOptions opts)
{
bool prep_mul = opts.prep_mul;
Timer timer;
timer.start();
Player& P = proc.P;
auto& prep = proc.DataF;
size_t start = P.sent + prep.data_sent();
auto stats = P.comm_stats + prep.comm_stats();
auto& extra_player = P;
auto& protocol = proc.protocol;
auto& MCp = proc.MC;
typedef T<typename P256Element::Scalar> pShare;
typedef T<P256Element> cShare;
vector<pShare> inv_ks;
vector<cShare> secret_Rs;
prep.buffer_triples();
vector<pShare> bs, cs;
for (int i = 0; i < buffer_size; i++)
{
pShare a, a_inv;
prep.get_two(DATA_INVERSE, a, a_inv);
inv_ks.push_back(a_inv);
secret_Rs.push_back(a);
pShare a, b, c;
prep.get_three(DATA_TRIPLE, a, b, c);
inv_ks.push_back(a);
bs.push_back(b);
cs.push_back(c);
}
vector<P256Element::Scalar> cs_opened;
MCp.POpen_Begin(cs_opened, cs, extra_player);
if (opts.fewer_rounds)
secret_Rs.insert(secret_Rs.begin(), bs.begin(), bs.end());
else
{
MCp.POpen_End(cs_opened, cs, extra_player);
for (int i = 0; i < buffer_size; i++)
secret_Rs.push_back(bs[i] / cs_opened[i]);
}
vector<P256Element> opened_Rs;
typename cShare::Direct_MC MCc(MCp.get_alphai());
MCc.POpen_Begin(opened_Rs, secret_Rs, extra_player);
if (prep_mul)
{
protocol.init_mul(&proc);
for (int i = 0; i < buffer_size; i++)
protocol.prepare_mul(inv_ks[i], sk);
protocol.exchange();
protocol.start_exchange();
}
else
prep.buffer_triples();
vector<P256Element> opened_Rs;
typename cShare::MAC_Check MCc(MCp.get_alphai());
MCc.POpen(opened_Rs, secret_Rs, P);
MCc.Check(P);
MCp.Check(P);
if (opts.fewer_rounds)
MCp.POpen_End(cs_opened, cs, extra_player);
MCc.POpen_End(opened_Rs, secret_Rs, extra_player);
if (opts.fewer_rounds)
for (int i = 0; i < buffer_size; i++)
opened_Rs[i] /= cs_opened[i];
if (prep_mul)
protocol.stop_exchange();
if (opts.check_open)
MCc.Check(extra_player);
if (opts.check_open or opts.check_beaver_open)
MCp.Check(extra_player);
for (int i = 0; i < buffer_size; i++)
{
tuples.push_back(
@@ -68,6 +95,7 @@ void preprocessing(vector<EcTuple<T>>& tuples, int buffer_size,
<< " seconds, throughput " << buffer_size / timer.elapsed() << ", "
<< 1e-3 * (P.sent + prep.data_sent() - start) / buffer_size
<< " kbytes per tuple" << endl;
(P.comm_stats + prep.comm_stats() - stats).print(true);
}
template<template<class U> class T>

View File

@@ -8,10 +8,10 @@
#include "hm-ecdsa-party.hpp"
template<>
Preprocessing<Rep3Share<gfp>>* Preprocessing<Rep3Share<gfp>>::get_live_prep(
SubProcessor<Rep3Share<gfp>>* proc, DataPositions& usage)
Preprocessing<Rep3Share<gfp2>>* Preprocessing<Rep3Share<gfp2>>::get_live_prep(
SubProcessor<Rep3Share<gfp2>>* proc, DataPositions& usage)
{
return new ReplicatedPrep<Rep3Share<gfp>>(proc, usage);
return new ReplicatedPrep<Rep3Share<gfp2>>(proc, usage);
}
int main(int argc, const char** argv)

View File

@@ -71,12 +71,12 @@ EcSignature sign(const unsigned char* message, size_t length,
prod = protocol.finalize_mul();
}
auto rx = tuple.R.x();
signature.s = MC.POpen(
signature.s = MC.open(
tuple.a * hash_to_scalar(message, length) + prod * rx, P);
cout << "Minimal signing took " << timer.elapsed() * 1e3 << " ms and sending "
<< (P.sent - start) << " bytes" << endl;
auto diff = (P.comm_stats - stats);
diff.print();
diff.print(true);
return signature;
}
@@ -112,26 +112,30 @@ void check(EcSignature signature, const unsigned char* message, size_t length,
template<template<class U> class T>
void sign_benchmark(vector<EcTuple<T>>& tuples, T<P256Element::Scalar> sk,
typename T<P256Element::Scalar>::MAC_Check& MCp, Player& P,
EcdsaOptions& opts,
SubProcessor<T<P256Element::Scalar>>* proc = 0)
{
unsigned char message[1024];
GlobalPRNG(P).get_octets(message, 1024);
typename T<P256Element>::MAC_Check MCc(MCp.get_alphai());
typename T<P256Element>::Direct_MC MCc(MCp.get_alphai());
// synchronize
Bundle<octetStream> bundle(P);
P.Broadcast_Receive(bundle, true);
Timer timer;
timer.start();
P256Element pk = MCc.POpen(sk, P);
auto stats = P.comm_stats;
P256Element pk = MCc.open(sk, P);
MCc.Check(P);
cout << "Public key generation took " << timer.elapsed() * 1e3 << " ms" << endl;
P.comm_stats.print();
(P.comm_stats - stats).print(true);
for (size_t i = 0; i < min(10lu, tuples.size()); i++)
{
check(sign(message, 1 << i, tuples[i], MCp, P, pk, sk, proc), message,
1 << i, pk);
if (not opts.check_open)
continue;
Timer timer;
timer.start();
auto& check_player = MCp.get_check_player(P);

View File

@@ -365,6 +365,7 @@ int main(int argc, char** argv)
// init static gfp
string prep_data_prefix = get_prep_dir(nparties, 128, gf2n::default_degree());
initialise_fields(prep_data_prefix);
bigint::init_thread();
// Generate session keys to decrypt data sent from each spdz engine (party)
vector<octet*> session_keys(nparties);

View File

@@ -39,6 +39,8 @@ class Ciphertext
// Rely on default copy assignment/constructor
word get_pk_id() { return pk_id; }
void set(const Rq_Element& a0, const Rq_Element& a1, word pk_id)
{ cc0=a0; cc1=a1; this->pk_id = pk_id; }
void set(const Rq_Element& a0, const Rq_Element& a1, const FHE_PK& pk);

View File

@@ -9,13 +9,20 @@
template <class FD>
Multiplier<FD>::Multiplier(int offset, PairwiseGenerator<FD>& generator) :
generator(generator), machine(generator.machine),
P(generator.P, offset),
num_players(generator.P.num_players()),
my_num(generator.P.my_num()),
Multiplier(offset, generator.machine, generator.P, generator.timers)
{
}
template <class FD>
Multiplier<FD>::Multiplier(int offset, PairwiseMachine& machine, Player& P,
map<string, Timer>& timers) :
machine(machine),
P(P, offset),
num_players(P.num_players()),
my_num(P.my_num()),
other_pk(machine.other_pks[(my_num + num_players - offset) % num_players]),
other_enc_alpha(machine.enc_alphas[(my_num + num_players - offset) % num_players]),
timers(generator.timers),
timers(timers),
C(machine.pk), mask(machine.pk),
product_share(machine.setup<FD>().FieldD), rc(machine.pk),
volatile_capacity(0)

View File

@@ -21,7 +21,6 @@ class PairwiseMachine;
template <class FD>
class Multiplier
{
PairwiseGenerator<FD>& generator;
PairwiseMachine& machine;
OffsetPlayer P;
int num_players, my_num;
@@ -39,6 +38,9 @@ class Multiplier
public:
Multiplier(int offset, PairwiseGenerator<FD>& generator);
Multiplier(int offset, PairwiseMachine& machine, Player& P,
map<string, Timer>& timers);
void multiply_and_add(Plaintext_<FD>& res, const Ciphertext& C,
const Plaintext_<FD>& b);
void multiply_and_add(Plaintext_<FD>& res, const Ciphertext& C,

View File

@@ -13,6 +13,8 @@ using namespace std;
#include "Access.h"
#include "Processor/FixInput.h"
#include "Processor/ProcessorBase.hpp"
namespace GC
{

View File

@@ -25,9 +25,9 @@ void SemiPrep::set_protocol(Beaver<SemiSecret>& protocol)
(void) protocol;
params.set_passive();
triple_generator = new SemiSecret::TripleGenerator(
thread.processor.machine.ot_setups.at(thread.thread_num).at(0),
thread.processor.machine.ot_setups.at(thread.thread_num).get_fresh(),
thread.master.N, thread.thread_num, thread.master.opts.batch_size,
1, params, thread.P);
1, params, {}, thread.P);
triple_generator->multi_threaded = false;
}

View File

@@ -95,7 +95,7 @@ ShareParty<T>::ShareParty(int argc, const char** argv, int default_batch_size) :
else
P = new PlainPlayer(this->N, 0xFFFF);
for (int i = 0; i < this->machine.nthreads; i++)
this->machine.ot_setups.push_back({{{*P, true}}});
this->machine.ot_setups.push_back({*P, true});
delete P;
}

View File

@@ -152,8 +152,7 @@ void ReplicatedSecret<U>::reveal(size_t n_bits, Clear& x)
auto& share = *this;
vector<BitVec> opened;
auto& party = ShareThread<U>::s();
party.MC->POpen_Begin(opened, {share}, *party.P);
party.MC->POpen_End(opened, {share}, *party.P);
party.MC->POpen(opened, {share}, *party.P);
x = IntBase(opened[0]);
}

View File

@@ -59,7 +59,7 @@ void ThreadMaster<T>::run()
if (T::needs_ot)
for (int i = 0; i < machine.nthreads; i++)
machine.ot_setups.push_back({{*P, true}, {*P, true}});
machine.ot_setups.push_back({*P, true});
for (int i = 0; i < machine.nthreads; i++)
threads.push_back(new_thread(i));

View File

@@ -31,18 +31,17 @@ void TinyPrep<T>::set_protocol(Beaver<T>& protocol)
params.generateMACs = true;
params.amplify = false;
params.check = false;
params.set_mac_key(thread.MC->get_alphai());
triple_generator = new typename T::TripleGenerator(
thread.processor.machine.ot_setups.at(thread.thread_num).at(0),
thread.processor.machine.ot_setups.at(thread.thread_num).get_fresh(),
thread.master.N, thread.thread_num,
thread.master.opts.batch_size,
1, params, thread.P);
1, params, thread.MC->get_alphai(), thread.P);
triple_generator->multi_threaded = false;
input_generator = new typename T::part_type::TripleGenerator(
thread.processor.machine.ot_setups.at(thread.thread_num).at(1),
thread.processor.machine.ot_setups.at(thread.thread_num).get_fresh(),
thread.master.N, thread.thread_num,
thread.master.opts.batch_size,
1, params, thread.P);
1, params, thread.MC->get_alphai(), thread.P);
input_generator->multi_threaded = false;
thread.MC->get_part_MC().set_prep(*this);
}

View File

@@ -264,11 +264,8 @@ OTMachine::OTMachine(int argc, const char** argv)
// convert baseOT selection bits to BitVector
// (not already BitVector due to legacy PVW code)
baseReceiverInput = bot.receiver_inputs;
baseReceiverInput.resize(nbase);
for (int i = 0; i < nbase; i++)
{
baseReceiverInput.set_bit(i, bot.receiver_inputs[i]);
}
}
OTMachine::~OTMachine()

View File

@@ -107,7 +107,7 @@ int spdz_main(int argc, const char** argv, ez::ezOptionParser& opt, bool live_pr
1, // Number of args expected.
0, // Delimiter if expecting multiple args.
"Maximum number of parties to send to at once", // Help description.
"-b", // Flag token.
"-B", // Flag token.
"--max-broadcast" // Flag token.
);
opt.add(

View File

@@ -118,6 +118,7 @@ TripleMachine::TripleMachine(int argc, const char** argv) :
opt.get("-l")->getInt(nloops);
generateBits = opt.get("-B")->isSet;
check = opt.get("-c")->isSet || generateBits;
correlation_check = opt.get("-c")->isSet;
generateMACs = opt.get("-m")->isSet || check;
amplify = opt.get("-a")->isSet || generateMACs;
primeField = opt.get("-P")->isSet;
@@ -143,21 +144,22 @@ TripleMachine::TripleMachine(int argc, const char** argv) :
// doesn't work with Montgomery multiplication
gfp1::init_field(p, false);
gfp::init_field(p, true);
gf2n_long::init_field(128);
PRNG G;
G.ReSeed();
mac_key2l.randomize(G);
mac_key2s.randomize(G);
mac_key2.randomize(G);
mac_keyp.randomize(G);
mac_keyz.randomize(G);
}
template<class T>
GeneratorThread* TripleMachine::new_generator(OTTripleSetup& setup, int i)
GeneratorThread* TripleMachine::new_generator(OTTripleSetup& setup, int i,
typename T::mac_key_type mac_key)
{
return new typename T::TripleGenerator(setup, N[i % nConnections], i,
nTriplesPerThread, nloops, *this);
nTriplesPerThread, nloops, *this, mac_key);
}
void TripleMachine::run()
@@ -180,24 +182,24 @@ void TripleMachine::run()
for (int i = 0; i < nthreads; i++)
{
if (primeField)
generators[i] = new_generator<Share<gfp>>(setup, i);
generators[i] = new_generator<Share<gfp>>(setup, i, mac_keyp);
else if (z2k)
{
if (z2k == 32 and z2s == 32)
generators[i] = new_generator<Spdz2kShare<32, 32>>(setup, i);
generators[i] = new_generator<Spdz2kShare<32, 32>>(setup, i, mac_keyz);
else if (z2k == 64 and z2s == 64)
generators[i] = new_generator<Spdz2kShare<64, 64>>(setup, i);
generators[i] = new_generator<Spdz2kShare<64, 64>>(setup, i, mac_keyz);
else if (z2k == 64 and z2s == 48)
generators[i] = new_generator<Spdz2kShare<64, 48>>(setup, i);
generators[i] = new_generator<Spdz2kShare<64, 48>>(setup, i, mac_keyz);
else if (z2k == 66 and z2s == 64)
generators[i] = new_generator<Spdz2kShare<66, 64>>(setup, i);
generators[i] = new_generator<Spdz2kShare<66, 64>>(setup, i, mac_keyz);
else if (z2k == 66 and z2s == 48)
generators[i] = new_generator<Spdz2kShare<66, 48>>(setup, i);
generators[i] = new_generator<Spdz2kShare<66, 48>>(setup, i, mac_keyz);
else
throw runtime_error("not compiled for k=" + to_string(z2k) + " and s=" + to_string(z2s));
}
else
generators[i] = new_generator<Share<gf2n>>(setup, i);
generators[i] = new_generator<Share<gf2n>>(setup, i, mac_key2);
}
ntriples = generators[0]->nTriples * nthreads;
cout <<"Setup generators\n";
@@ -251,10 +253,8 @@ void TripleMachine::run()
void TripleMachine::output_mac_keys()
{
if (z2k) {
write_mac_keys(prep_data_dir, my_num, nplayers, mac_keyz, mac_key2l);
write_mac_keys(prep_data_dir, my_num, nplayers, mac_keyz, mac_key2);
}
else if (gf2n::degree() > 64)
write_mac_keys(prep_data_dir, my_num, nplayers, mac_keyp, mac_key2l);
else
write_mac_keys(prep_data_dir, my_num, nplayers, mac_keyp, mac_key2s);
write_mac_keys(prep_data_dir, my_num, nplayers, mac_keyp, mac_key2);
}

29
Machines/hemi-party.cpp Normal file
View File

@@ -0,0 +1,29 @@
/*
* hemi-party.cpp
*
*/
#include "Protocols/HemiShare.h"
#include "Math/gfp.h"
#include "Math/gf2n.h"
#include "FHE/P2Data.h"
#include "Tools/ezOptionParser.h"
#include "Player-Online.hpp"
#include "Protocols/HemiPrep.hpp"
#include "Processor/Data_Files.hpp"
#include "Processor/Instruction.hpp"
#include "Processor/Machine.hpp"
#include "Protocols/SemiPrep.hpp"
#include "Protocols/SemiInput.hpp"
#include "Protocols/MAC_Check_Base.hpp"
#include "Protocols/MAC_Check.hpp"
#include "Protocols/fake-stuff.hpp"
#include "Protocols/SemiMC.hpp"
#include "Protocols/Beaver.hpp"
int main(int argc, const char** argv)
{
ez::ezOptionParser opt;
spdz_main<HemiShare<gfp>, HemiShare<gf2n_short>>(argc, argv, opt);
}

View File

@@ -38,7 +38,7 @@ DEPS := $(wildcard */*.d)
all: gen_input online offline externalIO bmr yao replicated shamir real-bmr spdz2k-party.x brain-party.x semi-party.x semi2k-party.x semi-bin-party.x mascot-party.x tiny-party.x
ifeq ($(USE_NTL),1)
all: overdrive she-offline cowgear-party.x
all: overdrive she-offline cowgear-party.x hemi-party.x
endif
-include $(DEPS)
@@ -165,6 +165,7 @@ malicious-shamir-party.x: Machines/ShamirMachine.o
spdz2k-party.x: $(OT)
semi-party.x: $(OT)
semi2k-party.x: $(OT)
hemi-party.x: $(FHEOFFLINE)
cowgear-party.x: $(FHEOFFLINE) Protocols/CowGearOptions.o
mascot-party.x: Machines/SPDZ.o $(OT)
Player-Online.x: Machines/SPDZ.o $(OT)

View File

@@ -110,6 +110,13 @@ void generate_online_setup(ofstream& outf, string dirname, bigint& p, int lgp, i
}
void write_online_setup(ofstream& outf, string dirname, const bigint& p, int lg2, bool mont)
{
write_online_setup_without_init(outf, dirname, p, lg2);
gfp::init_field(p, mont);
init_gf2n(lg2);
}
void write_online_setup_without_init(ofstream& outf, string dirname, const bigint& p, int lg2)
{
if (p == 0)
throw runtime_error("prime cannot be 0");
@@ -132,9 +139,6 @@ void write_online_setup(ofstream& outf, string dirname, const bigint& p, int lg2
// Fix as a negative lg2 is a ``signal'' to choose slightly weaker
// LWE parameters
outf << abs(lg2) << endl;
gfp::init_field(p, mont);
init_gf2n(lg2);
}
void init_gf2n(int lg2)

View File

@@ -22,6 +22,7 @@ using namespace std;
// Create setup file for gfp and gf2n
void generate_online_setup(ofstream& outf, string dirname, bigint& p, int lgp, int lg2);
void write_online_setup(ofstream& outf, string dirname, const bigint& p, int lg2, bool mont = true);
void write_online_setup_without_init(ofstream& outf, string dirname, const bigint& p, int lg2);
// Setup primes only
// Chooses a p of at least lgp bits

View File

@@ -15,14 +15,14 @@ void Square<gf2n_short>::to(gf2n_short& result)
result = sum;
}
template <>
void Square<gfp1>::to(gfp1& result)
template<class U>
template<int X, int L>
void Square<U>::to(gfp_<X, L>& result)
{
const int L = gfp1::N_LIMBS;
mp_limb_t product[2 * L], sum[2 * L], tmp[L][2 * L];
memset(tmp, 0, sizeof(tmp));
memset(sum, 0, sizeof(sum));
for (int i = 0; i < gfp1::length(); i++)
for (int i = 0; i < gfp_<X, L>::length(); i++)
{
memcpy(&(tmp[i/64][i/64]), &(rows[i]), sizeof(rows[i]));
if (i % 64 == 0)
@@ -32,10 +32,22 @@ void Square<gfp1>::to(gfp1& result)
mpn_add_fixed_n<2 * L>(sum, product, sum);
}
mp_limb_t q[2 * L], ans[2 * L];
mpn_tdiv_qr(q, ans, 0, sum, 2 * L, gfp1::get_ZpD().get_prA(), L);
mpn_tdiv_qr(q, ans, 0, sum, 2 * L, gfp_<X, L>::get_ZpD().get_prA(), L);
result.assign((void*) ans);
}
template<>
void Square<gfp1>::to(gfp1& result)
{
to<1, GFP_MOD_SZ>(result);
}
template<>
void Square<gfp3>::to(gfp3& result)
{
to<3, 4>(result);
}
template<>
void Square<BitVec>::to(BitVec& result)
{

View File

@@ -31,6 +31,8 @@ public:
void conditional_add(BitVector& conditions, Square& other,
int offset);
void to(U& result);
template<int X, int L>
void to(gfp_<X, L>& result);
void pack(octetStream& os) const;
void unpack(octetStream& os);

View File

@@ -283,7 +283,8 @@ Z2<K> Z2<K>::operator>>(int i) const
{
Z2<K> res;
int n_byte_shift = i / 8;
memcpy(res.a, (char*)a + n_byte_shift, N_BYTES - n_byte_shift);
if (N_BYTES - n_byte_shift > 0)
memcpy(res.a, (char*)a + n_byte_shift, N_BYTES - n_byte_shift);
int n_inside_shift = i % 8;
if (n_inside_shift > 0)
mpn_rshift(res.a, res.a, N_WORDS, n_inside_shift);

View File

@@ -147,4 +147,17 @@ ostream& operator<<(ostream& o, const Z2<K>& x)
return o;
}
template<int K>
istream& operator>>(istream& i, SignedZ2<K>& x)
{
auto& tmp = bigint::tmp;
i >> tmp;
if (tmp.numBits() > K + 1)
throw runtime_error(
tmp.get_str() + " out of range for signed " + to_string(K)
+ "-bit numbers");
x = tmp;
return i;
}
#endif

View File

@@ -262,6 +262,8 @@ typedef gfp_<0, GFP_MOD_SZ> gfp;
typedef gfp_<1, GFP_MOD_SZ> gfp1;
// enough for Brain protocol with 64-bit computation and 40-bit security
typedef gfp_<2, 4> gfp2;
// for OT-based ECDSA
typedef gfp_<3, 4> gfp3;
void to_signed_bigint(bigint& ans,const gfp& x);

View File

@@ -110,7 +110,7 @@ inline mp_limb_t mpn_add_n_with_carry(mp_limb_t* res, const mp_limb_t* x, const
char carry = 0;
for (int i = 0; i < n; i++)
#if defined(__clang__)
#if __clang_major__ < 8 || defined(__APPLE__)
#if __clang_major__ < 8 || (defined(__APPLE__) && __clang_major__ < 11)
carry = __builtin_ia32_addcarry_u64(carry, x[i], y[i], (unsigned long long*)&res[i]);
#else
carry = __builtin_ia32_addcarryx_u64(carry, x[i], y[i], (unsigned long long*)&res[i]);

View File

@@ -97,12 +97,20 @@ Names::Names(ez::ezOptionParser& opt, int argc, const char** argv,
1, // Required?
1, // Number of args expected.
0, // Delimiter if expecting multiple args.
"This player's number", // Help description.
"This player's number (required)", // Help description.
"-p", // Flag token.
"--player" // Flag token.
);
opt.parse(argc, argv);
opt.get("-p")->getInt(player_no);
vector<string> missing;
if (not opt.gotRequired(missing))
{
string usage;
opt.getUsage(usage);
cerr << usage;
exit(1);
}
global_server = network_opts.start_networking(*this, player_no);
}
@@ -123,7 +131,7 @@ void Names::setup_names(const char *servername, int my_port)
set_up_client_socket(socket_num, servername, pn);
send(socket_num, (octet*)&player_no, sizeof(player_no));
#ifdef DEBUG_NETWORKING
fprintf(stderr, "Sent %d to %s:%d\n", player_no, servername, pn);
cerr << "Sent " << player_no << " to " << servername << ":" << pn << endl;
#endif
int inst=-1; // wait until instruction to start.
@@ -338,6 +346,14 @@ void MultiPlayer<T>::send_all(const octetStream& o,bool donthash) const
}
void Player::receive_all(vector<octetStream>& os) const
{
for (int j = 0; j < num_players(); j++)
if (j != my_num())
receive_player(j, os[j], true);
}
void Player::receive_player(int i,octetStream& o,bool donthash) const
{
#ifdef VERBOSE_COMM
@@ -345,6 +361,7 @@ void Player::receive_player(int i,octetStream& o,bool donthash) const
#endif
TimeScope ts(timer);
receive_player_no_stats(i, o);
comm_stats["Receiving directly"].add(o, ts);
if (!donthash)
{ blk_SHA1_Update(&ctx,o.get_data(),o.get_length()); }
}
@@ -627,12 +644,14 @@ void RealTwoPartyPlayer::receive(octetStream& o) const
TimeScope ts(timer);
o.reset_write_head();
o.Receive(socket);
comm_stats["Receiving one-to-one"].add(o, ts);
}
void VirtualTwoPartyPlayer::receive(octetStream& o) const
{
TimeScope ts(timer);
P.receive_player_no_stats(other_player, o);
comm_stats["Receiving one-to-one"].add(o, ts);
}
void RealTwoPartyPlayer::send_receive_player(vector<octetStream>& o) const
@@ -688,6 +707,8 @@ void TwoPartyPlayer::Broadcast_Receive(vector<octetStream>& o,
CommStats& CommStats::operator +=(const CommStats& other)
{
data += other.data;
rounds += other.rounds;
timer += other.timer;
return *this;
}
@@ -698,6 +719,13 @@ NamedCommStats& NamedCommStats::operator +=(const NamedCommStats& other)
return *this;
}
NamedCommStats NamedCommStats::operator +(const NamedCommStats& other) const
{
auto res = *this;
res += other;
return res;
}
CommStats& CommStats::operator -=(const CommStats& other)
{
data -= other.data;
@@ -722,13 +750,15 @@ size_t NamedCommStats::total_data()
return res;
}
void NamedCommStats::print()
void NamedCommStats::print(bool newline)
{
for (auto it = begin(); it != end(); it++)
if (it->second.data)
cerr << it->first << " " << 1e-6 * it->second.data << " MB in "
<< it->second.rounds << " rounds, taking " << it->second.timer.elapsed()
<< " seconds" << endl;
if (size() and newline)
cerr << endl;
}
template class MultiPlayer<int>;

View File

@@ -91,6 +91,7 @@ struct CommStats
Timer timer;
CommStats() : data(0), rounds(0) {}
Timer& add(const octetStream& os) { data += os.get_length(); rounds++; return timer; }
void add(const octetStream& os, const TimeScope& scope) { add(os) += scope; }
CommStats& operator+=(const CommStats& other);
CommStats& operator-=(const CommStats& other);
};
@@ -99,9 +100,10 @@ class NamedCommStats : public map<string, CommStats>
{
public:
NamedCommStats& operator+=(const NamedCommStats& other);
NamedCommStats operator+(const NamedCommStats& other) const;
NamedCommStats operator-(const NamedCommStats& other) const;
size_t total_data();
void print();
void print(bool newline = false);
#ifdef VERBOSE_COMM
CommStats& operator[](const string& name)
{
@@ -160,6 +162,7 @@ public:
virtual void send_all(const octetStream& o,bool donthash=false) const = 0;
void send_to(int player,const octetStream& o,bool donthash=false) const;
virtual void send_to_no_stats(int player,const octetStream& o) const = 0;
void receive_all(vector<octetStream>& os) const;
void receive_player(int i,octetStream& o,bool donthash=false) const;
virtual void receive_player_no_stats(int i,octetStream& o) const = 0;
virtual void receive_player(int i,FlexBuffer& buffer) const;

View File

@@ -83,84 +83,40 @@ ServerSocket::~ServerSocket()
void ServerSocket::accept_clients()
{
map<int, sockaddr> unassigned_sockets;
while (true)
{
fd_set readfds;
FD_ZERO(&readfds);
int nfds = main_socket;
FD_SET(main_socket, &readfds);
for (auto &socket : unassigned_sockets)
{
FD_SET(socket.first, &readfds);
nfds = max(socket.first, nfds);
}
struct sockaddr dest;
memset(&dest, 0, sizeof(dest)); /* zero the struct before filling the fields */
int socksize = sizeof(dest);
int consocket = accept(main_socket, (struct sockaddr *)&dest, (socklen_t*) &socksize);
if (consocket<0) { error("set_up_socket:accept"); }
select(nfds + 1, &readfds, 0, 0, 0);
if (FD_ISSET(main_socket, &readfds))
{
struct sockaddr dest;
memset(&dest, 0, sizeof(dest)); /* zero the struct before filling the fields */
int socksize = sizeof(dest);
int consocket = accept(main_socket, (struct sockaddr*) &dest,
(socklen_t*) &socksize);
if (consocket < 0)
error("set_up_socket:accept");
unassigned_sockets[consocket] = dest;
int client_id;
try
{
receive(consocket, (unsigned char*)&client_id, sizeof(client_id));
}
catch (closed_connection&)
{
#ifdef DEBUG_NETWORKING
auto &conn = *(sockaddr_in*) &dest;
fprintf(stderr, "new client on %s:%d\n", inet_ntoa(conn.sin_addr),
ntohs(conn.sin_port));
auto& conn = *(sockaddr_in*)&dest;
cerr << "client on " << inet_ntoa(conn.sin_addr) << ":"
<< ntohs(conn.sin_port) << " left without identification"
<< endl;
#endif
}
}
vector<int> processed_sockets;
for (auto &socket : unassigned_sockets)
{
int consocket = socket.first;
if (FD_ISSET(consocket, &readfds))
{
try
{
int client_id;
receive(consocket, (unsigned char*) &client_id,
sizeof(client_id));
data_signal.lock();
clients[client_id] = consocket;
data_signal.broadcast();
data_signal.unlock();
#ifdef DEBUG_NETWORKING
auto &conn = *(sockaddr_in*) &socket.second;
fprintf(stderr, "client id %d on %s:%d\n", client_id,
inet_ntoa(conn.sin_addr), ntohs(conn.sin_port));
#endif
data_signal.lock();
clients[client_id] = consocket;
data_signal.broadcast();
data_signal.unlock();
#ifdef __APPLE__
int flags = fcntl(consocket, F_GETFL, 0);
int fl = fcntl(consocket, F_SETFL, O_NONBLOCK | flags);
if (fl < 0)
error("set non-blocking");
int flags = fcntl(consocket, F_GETFL, 0);
int fl = fcntl(consocket, F_SETFL, O_NONBLOCK | flags);
if (fl < 0)
error("set non-blocking");
#endif
}
catch (closed_connection&)
{
#ifdef DEBUG_NETWORKING
auto &conn = *(sockaddr_in*) &socket.second;
cerr << "client on " << inet_ntoa(conn.sin_addr) << ":"
<< ntohs(conn.sin_port) << " left without identification"
<< endl;
#endif
close_client_socket(consocket);
}
processed_sockets.push_back(consocket);
}
}
for (int socket : processed_sockets)
unassigned_sockets.erase(socket);
}
}

View File

@@ -116,7 +116,7 @@ void BaseOT::exec_base(bool new_receiver_inputs)
{
if (new_receiver_inputs)
receiver_inputs[i + j] = G.get_uchar()&1;
cs[j] = receiver_inputs[i + j];
cs[j] = receiver_inputs[i + j].get();
}
receiver_rsgen(&receiver, Rs_pack[0], cs);
os[0].store_bytes(Rs_pack[0], sizeof(Rs_pack[0]));
@@ -293,7 +293,7 @@ void FakeOT::exec_base(bool new_receiver_inputs)
{
for (int j = 0; j < 2; j++)
bv[j].unpack(os[1]);
receiver_outputs[i] = bv[receiver_inputs[i]];
receiver_outputs[i] = bv[receiver_inputs[i].get()];
}
set_seeds();

View File

@@ -30,7 +30,7 @@ void send_if_ot_receiver(TwoPartyPlayer* P, vector<octetStream>& os, OT_ROLE rol
class BaseOT
{
public:
vector<int> receiver_inputs;
BitVector receiver_inputs;
vector< vector<BitVector> > sender_inputs;
vector<BitVector> receiver_outputs;
TwoPartyPlayer* P;
@@ -63,7 +63,7 @@ public:
int length() { return ot_length; }
void set_receiver_inputs(const vector<int>& new_inputs)
void set_receiver_inputs(const BitVector& new_inputs)
{
if ((int)new_inputs.size() != nOT)
throw invalid_length();
@@ -72,7 +72,7 @@ public:
void set_receiver_inputs(int128 inputs)
{
vector<int> new_inputs(128);
BitVector new_inputs(128);
for (int i = 0; i < 128; i++)
new_inputs[i] = (inputs >> i).get_lower() & 1;
set_receiver_inputs(new_inputs);
@@ -81,6 +81,7 @@ public:
// do the OTs -- generate fresh random choice bits by default
virtual void exec_base(bool new_receiver_inputs=true);
// use PRG to get the next ot_length bits
void set_seeds();
void extend_length();
void check();
@@ -90,8 +91,6 @@ protected:
bool is_sender() { return (bool) (ot_role & SENDER); }
bool is_receiver() { return (bool) (ot_role & RECEIVER); }
void set_seeds();
};
class FakeOT : public BaseOT

View File

@@ -702,4 +702,5 @@ BMS
XXXX(Matrix<gf2n_short_square>, gf2n_short)
XXXX(Matrix<Square<gf2n_long>>, gf2n_long)
XXXX(Matrix<Square<gfp1>>, gfp1)
XXXX(Matrix<Square<gfp3>>, gfp3)
XXXX(Matrix<BitDiagonal>, BitVec)

View File

@@ -142,7 +142,7 @@ public:
BitMatrix() {}
BitMatrix(int length);
__m128i operator[](int i) { return squares[i / 128].rows[i % 128]; }
__m128i& operator[](int i) { return squares[i / 128].rows[i % 128]; }
void resize(int length);
int size();

View File

@@ -26,81 +26,15 @@ MascotParams::MascotParams()
generateMACs = true;
amplify = true;
check = true;
correlation_check = true;
generateBits = false;
use_extension = true;
fewer_rounds = false;
fiat_shamir = false;
timerclear(&start);
}
void MascotParams::set_passive()
{
generateMACs = amplify = check = false;
}
template<> gf2n_long MascotParams::get_mac_key()
{
return mac_key2l;
}
template<> gf2n_short MascotParams::get_mac_key()
{
return mac_key2s;
}
template<> gfp1 MascotParams::get_mac_key()
{
return mac_keyp;
}
template<> Z2<48> MascotParams::get_mac_key()
{
return mac_keyz;
}
template<> Z2<64> MascotParams::get_mac_key()
{
return mac_keyz;
}
template<> Z2<40> MascotParams::get_mac_key()
{
return mac_keyz;
}
template<> Z2<32> MascotParams::get_mac_key()
{
return mac_keyz;
}
template<> BitVec MascotParams::get_mac_key()
{
return 0;
}
template<> void MascotParams::set_mac_key(gf2n_long key)
{
mac_key2l = key;
}
template<> void MascotParams::set_mac_key(gf2n_short key)
{
mac_key2s = key;
}
template<> void MascotParams::set_mac_key(gfp1 key)
{
mac_keyp = key;
}
template<> void MascotParams::set_mac_key(Z2<64> key)
{
mac_keyz = key;
}
template<> void MascotParams::set_mac_key(Z2<48> key)
{
mac_keyz = key;
}
template<> void MascotParams::set_mac_key(Z2<40> key)
{
mac_keyz = key;
generateMACs = amplify = check = correlation_check = false;
}

View File

@@ -68,6 +68,8 @@ protected:
SeededPRNG share_prg;
mac_key_type mac_key;
void start_progress();
void print_progress(int k);
@@ -101,8 +103,11 @@ public:
vector<PlainTriple<open_type, N_AMPLIFY>> preampTriples;
vector<array<open_type, 3>> plainTriples;
OTTripleGenerator(OTTripleSetup& setup, const Names& names,
typename T::MAC_Check* MC;
OTTripleGenerator(const OTTripleSetup& setup, const Names& names,
int thread_num, int nTriples, int nloops, MascotParams& machine,
mac_key_type mac_key,
Player* parentPlayer = 0);
~OTTripleGenerator();
@@ -113,7 +118,10 @@ public:
void run_multipliers(MultJob job);
mac_key_type get_mac_key() const { return mac_key; }
size_t data_sent();
NamedCommStats comm_stats();
};
template<class T>
@@ -130,8 +138,9 @@ public:
vector< ShareTriple_<sacri_type, mac_key_type, 2> > uncheckedTriples;
vector<InputTuple<Share<sacri_type>>> inputs;
NPartyTripleGenerator(OTTripleSetup& setup, const Names& names,
NPartyTripleGenerator(const OTTripleSetup& setup, const Names& names,
int thread_num, int nTriples, int nloops, MascotParams& machine,
mac_key_type mac_key,
Player* parentPlayer = 0);
virtual ~NPartyTripleGenerator() {}
@@ -159,8 +168,9 @@ class MascotTripleGenerator : public NPartyTripleGenerator<T>
public:
vector<T> bits;
MascotTripleGenerator(OTTripleSetup& setup, const Names& names,
MascotTripleGenerator(const OTTripleSetup& setup, const Names& names,
int thread_num, int nTriples, int nloops, MascotParams& machine,
mac_key_type mac_key,
Player* parentPlayer = 0);
};
@@ -181,8 +191,9 @@ class Spdz2kTripleGenerator : public NPartyTripleGenerator<T>
U& MC, PRNG& G);
public:
Spdz2kTripleGenerator(OTTripleSetup& setup, const Names& names,
Spdz2kTripleGenerator(const OTTripleSetup& setup, const Names& names,
int thread_num, int nTriples, int nloops, MascotParams& machine,
mac_key_type mac_key,
Player* parentPlayer = 0);
void generateTriples();
@@ -199,4 +210,15 @@ size_t OTTripleGenerator<T>::data_sent()
return res;
}
template<class T>
NamedCommStats OTTripleGenerator<T>::comm_stats()
{
NamedCommStats res;
if (parentPlayer != &globalPlayer)
res = globalPlayer.comm_stats;
for (auto& player : players)
res += player->comm_stats;
return res;
}
#endif

View File

@@ -5,27 +5,14 @@
#include "OT/OTExtensionWithMatrix.h"
#include "OT/OTMultiplier.h"
#include "Math/gfp.h"
#include "Protocols/Share.h"
#include "Protocols/SemiShare.h"
#include "Protocols/Semi2kShare.h"
#include "Protocols/Spdz2kShare.h"
#include "Math/operators.h"
#include "Tools/Subroutines.h"
#include "Protocols/MAC_Check.h"
#include "Protocols/Spdz2kPrep.h"
#include "GC/SemiSecret.h"
#include "OT/Triple.hpp"
#include "OT/Rectangle.hpp"
#include "OT/OTMultiplier.hpp"
#include "Protocols/MAC_Check.hpp"
#include "Protocols/SemiMC.h"
#include "Protocols/MascotPrep.hpp"
#include "Protocols/ReplicatedInput.hpp"
#include "Protocols/SemiInput.hpp"
#include "Processor/Input.hpp"
#include "Math/Z2k.hpp"
#include <sstream>
#include <fstream>
@@ -43,44 +30,46 @@ void* run_ot_thread(void* ptr)
* N.B. setup must not be stored as it will be used by other threads
*/
template<class T>
NPartyTripleGenerator<T>::NPartyTripleGenerator(OTTripleSetup& setup,
NPartyTripleGenerator<T>::NPartyTripleGenerator(const OTTripleSetup& setup,
const Names& names, int thread_num, int _nTriples, int nloops,
MascotParams& machine, Player* parentPlayer) :
MascotParams& machine, mac_key_type mac_key, Player* parentPlayer) :
OTTripleGenerator<T>(setup, names, thread_num, _nTriples, nloops,
machine, parentPlayer)
machine, mac_key, parentPlayer)
{
}
template<class T>
MascotTripleGenerator<T>::MascotTripleGenerator(OTTripleSetup& setup,
MascotTripleGenerator<T>::MascotTripleGenerator(const OTTripleSetup& setup,
const Names& names, int thread_num, int _nTriples, int nloops,
MascotParams& machine, Player* parentPlayer) :
MascotParams& machine, mac_key_type mac_key, Player* parentPlayer) :
NPartyTripleGenerator<T>(setup, names, thread_num, _nTriples, nloops,
machine, parentPlayer)
machine, mac_key, parentPlayer)
{
}
template<class T>
Spdz2kTripleGenerator<T>::Spdz2kTripleGenerator(OTTripleSetup& setup,
Spdz2kTripleGenerator<T>::Spdz2kTripleGenerator(const OTTripleSetup& setup,
const Names& names, int thread_num, int _nTriples, int nloops,
MascotParams& machine, Player* parentPlayer) :
MascotParams& machine, mac_key_type mac_key, Player* parentPlayer) :
NPartyTripleGenerator<T>(setup, names, thread_num, _nTriples, nloops,
machine, parentPlayer)
machine, mac_key, parentPlayer)
{
}
template<class T>
OTTripleGenerator<T>::OTTripleGenerator(OTTripleSetup& setup,
OTTripleGenerator<T>::OTTripleGenerator(const OTTripleSetup& setup,
const Names& names, int thread_num, int _nTriples, int nloops,
MascotParams& machine, Player* parentPlayer) :
MascotParams& machine, mac_key_type mac_key, Player* parentPlayer) :
globalPlayer(parentPlayer ? *parentPlayer : *new PlainPlayer(names,
- thread_num * names.num_players() * names.num_players())),
parentPlayer(parentPlayer),
thread_num(thread_num),
mac_key(mac_key),
my_num(setup.get_my_num()),
nloops(nloops),
nparties(setup.get_nparties()),
machine(machine)
machine(machine),
MC(0)
{
nTriplesPerLoop = DIV_CEIL(_nTriples, nloops);
nTriples = nTriplesPerLoop * nloops;
@@ -208,7 +197,6 @@ void NPartyTripleGenerator<W>::generateInputs(int player)
{
typedef open_type T;
auto& machine = this->machine;
auto& nTriplesPerLoop = this->nTriplesPerLoop;
auto& valueBits = this->valueBits;
auto& share_prg = this->share_prg;
@@ -235,7 +223,7 @@ void NPartyTripleGenerator<W>::generateInputs(int player)
GlobalPRNG G(globalPlayer);
Share<T> check_sum;
inputs.resize(toCheck);
auto mac_key = machine.template get_mac_key<mac_key_type>();
auto mac_key = this->get_mac_key();
SemiInput<SemiShare<T>> input(0, globalPlayer);
input.reset_all(globalPlayer);
vector<T> secrets(toCheck);
@@ -289,7 +277,7 @@ void MascotTripleGenerator<T>::generateBitsGf2n()
bits.resize(nBitsToCheck);
vector<T> to_open(1);
vector<typename T::clear> opened(1);
MAC_Check_<T> MC(this->machine.template get_mac_key<typename T::clear>());
MAC_Check_<T> MC(this->get_mac_key());
this->start_progress();
@@ -313,7 +301,7 @@ void MascotTripleGenerator<T>::generateBitsGf2n()
typename T::clear r;
for (int j = 0; j < nBitsToCheck; j++)
{
auto mac_sum = valueBits[0].get_bit(j) ? MC.get_alphai() : 0;
auto mac_sum = valueBits[0].get_bit(j) ? this->get_mac_key() : 0;
for (int i = 0; i < this->nparties-1; i++)
mac_sum += this->ot_multipliers[i]->macs[0][j];
bits[j].set_share(valueBits[0].get_bit(j));
@@ -352,6 +340,13 @@ void MascotTripleGenerator<Share<gfp1>>::generateBits()
generateTriples();
}
template<>
inline
void MascotTripleGenerator<Share<gfp3>>::generateBits()
{
generateTriples();
}
template<class T>
void Spdz2kTripleGenerator<T>::generateTriples()
{
@@ -360,7 +355,6 @@ void Spdz2kTripleGenerator<T>::generateTriples()
auto& uncheckedTriples = this->uncheckedTriples;
auto& timers = this->timers;
auto& machine = this->machine;
auto& nTriplesPerLoop = this->nTriplesPerLoop;
auto& valueBits = this->valueBits;
auto& share_prg = this->share_prg;
@@ -382,7 +376,7 @@ void Spdz2kTripleGenerator<T>::generateTriples()
vector< PlainTriple_<Z2<K + 2 * S>, Z2<K + S>, 2> > amplifiedTriples(nTriplesPerLoop);
uncheckedTriples.resize(nTriplesPerLoop);
MAC_Check_Z2k<Z2<K + 2 * S>, Z2<S>, Z2<K + S>, Share<Z2<K + 2 * S>> > MC(
machine.template get_mac_key<Z2<S> >());
this->get_mac_key());
this->start_progress();
@@ -455,7 +449,7 @@ void Spdz2kTripleGenerator<T>::generateTriples()
// get piggy-backed random value
Z2<K + 2 * S> r_share = b_padded_bits.get_ptr_to_byte(nTriplesPerLoop, Z2<K + 2 * S>::N_BYTES);
Z2<K + 2 * S> r_mac;
r_mac.mul(r_share, this->machine.template get_mac_key<Z2<S>>());
r_mac.mul(r_share, this->get_mac_key());
for (int i = 0; i < this->nparties-1; i++)
r_mac += (ot_multipliers[i])->macs.at(1).at(nTriplesPerLoop);
Share<Z2<K + 2 * S>> r;
@@ -563,16 +557,17 @@ void MascotTripleGenerator<U>::generateTriples()
valueBits[2*i].resize(field_size * nPreampTriplesPerLoop);
valueBits[1].resize(field_size * nTriplesPerLoop);
vector< PlainTriple<T,2> > amplifiedTriples;
MAC_Check<T> MC(machine.template get_mac_key<T>());
MAC_Check<T> MC(this->get_mac_key());
if (machine.amplify)
preampTriples.resize(nTriplesPerLoop);
if (machine.generateMACs)
{
amplifiedTriples.resize(nTriplesPerLoop);
uncheckedTriples.resize(nTriplesPerLoop);
}
uncheckedTriples.resize(nTriplesPerLoop);
this->start_progress();
for (int k = 0; k < nloops; k++)
@@ -581,10 +576,15 @@ void MascotTripleGenerator<U>::generateTriples()
if (machine.amplify)
{
octet seed[SEED_SIZE];
Create_Random_Seed(seed, globalPlayer, SEED_SIZE);
PRNG G;
G.SetSeed(seed);
if (machine.fiat_shamir and nparties == 2)
ot_multipliers[0]->otCorrelator.common_seed(G);
else
{
octet seed[SEED_SIZE];
Create_Random_Seed(seed, globalPlayer, SEED_SIZE);
G.SetSeed(seed);
}
for (int iTriple = 0; iTriple < nTriplesPerLoop; iTriple++)
{
PlainTriple<T,2> triple;
@@ -598,12 +598,16 @@ void MascotTripleGenerator<U>::generateTriples()
triple.output(outputFile);
timers["Writing"].stop();
}
else
for (int i = 0; i < 3; i++)
uncheckedTriples[iTriple].byIndex(i, 0).set_share(triple.byIndex(i, 0));
}
if (machine.generateMACs)
{
for (int iTriple = 0; iTriple < nTriplesPerLoop; iTriple++)
amplifiedTriples[iTriple].to(valueBits, iTriple);
amplifiedTriples[iTriple].to(valueBits, iTriple,
machine.check ? 2 : 1);
for (int i = 0; i < nparties-1; i++)
ot_multipliers[i]->inbox.push({});
@@ -625,7 +629,7 @@ void MascotTripleGenerator<U>::generateTriples()
if (machine.check)
{
sacrifice(uncheckedTriples, MC, G);
sacrifice(uncheckedTriples, this->MC ? *this->MC : MC, G);
}
}
}

View File

@@ -259,7 +259,7 @@ void naive_transpose64(vector<BitVector>& output, const vector<BitVector>& input
}
OTExtension::OTExtension(BaseOT& baseOT, TwoPartyPlayer* player,
OTExtension::OTExtension(const BaseOT& baseOT, TwoPartyPlayer* player,
bool passive) : player(player)
{
nbaseOTs = baseOT.nOT;

View File

@@ -30,7 +30,7 @@ public:
vector<BitVector> receiverOutput;
map<string,long long> times;
OTExtension(BaseOT& baseOT, TwoPartyPlayer* player, bool passive);
OTExtension(const BaseOT& baseOT, TwoPartyPlayer* player, bool passive);
OTExtension(int nbaseOTs, int baseLength,
int nloops, int nsubloops,

View File

@@ -308,7 +308,8 @@ void OTExtensionWithMatrix::hash_outputs(int nOTs)
}
template <class V>
void OTExtensionWithMatrix::hash_outputs(int nOTs, vector<V>& senderOutput, V& receiverOutput)
void OTExtensionWithMatrix::hash_outputs(int nOTs, vector<V>& senderOutput,
V& receiverOutput, bool correlated)
{
//cout << "Hashing... " << flush;
octetStream os, h_os(HASH_SIZE);
@@ -341,7 +342,11 @@ void OTExtensionWithMatrix::hash_outputs(int nOTs, vector<V>& senderOutput, V& r
for (int j = 0; j < 8; j++)
{
tmp[0][j] = senderOutputMatrices[0].squares[i_outer_input].rows[i_inner_input + j];
tmp[1][j] = tmp[0][j] ^ baseReceiverInput.get_int128(0);
if (correlated)
tmp[1][j] = tmp[0][j] ^ baseReceiverInput.get_int128(0);
else
tmp[1][j] =
senderOutputMatrices[1].squares[i_outer_input].rows[i_inner_input + j];
}
for (int j = 0; j < 2; j++)
mmo.hashBlocks<T, 8>(
@@ -366,17 +371,39 @@ void OTExtensionWithMatrix::hash_outputs(int nOTs, vector<V>& senderOutput, V& r
template <class U>
template <class T>
void OTCorrelator<U>::reduce_squares(unsigned int nTriples, vector<T>& output)
void OTCorrelator<U>::reduce_squares(unsigned int nTriples, vector<T>& output, int start)
{
if (receiverOutputMatrix.squares.size() < nTriples)
if (receiverOutputMatrix.squares.size() < nTriples + start)
throw invalid_length();
output.resize(nTriples);
for (unsigned int j = 0; j < nTriples; j++)
{
receiverOutputMatrix.squares[j].sub(senderOutputMatrices[0].squares[j]).to(output[j]);
receiverOutputMatrix.squares[j + start].sub(
senderOutputMatrices[0].squares[j + start]).to(output[j]);
}
}
template <class U>
void OTCorrelator<U>::common_seed(PRNG& G)
{
Slice<U> t1Slice(t1, 0, t1.squares.size());
Slice<U> uSlice(u, 0, u.squares.size());
octetStream os;
if (player->my_num())
{
t1Slice.pack(os);
uSlice.pack(os);
}
else
{
uSlice.pack(os);
t1Slice.pack(os);
}
auto hash = os.hash();
G = PRNG(hash);
}
octet* OTExtensionWithMatrix::get_receiver_output(int i)
{
return (octet*)&receiverOutputMatrix.squares[i/128].rows[i%128];
@@ -515,34 +542,35 @@ template class OTCorrelator<BitMatrix>;
#define Z(BM,GF) \
template class OTCorrelator<BM>; \
template void OTCorrelator<BM>::reduce_squares<GF>(unsigned int nTriples, \
vector<GF>& output);
vector<GF>& output, int);
#define ZZZZ(GF) \
template void OTExtensionWithMatrix::print_post_correlate<GF>( \
BitVector& newReceiverInput, int j, int offset, int sender); \
#define ZZZ(GF, M) Z(M, GF) \
template void OTExtensionWithMatrix::hash_outputs(int, vector<M >&, M&);
template void OTExtensionWithMatrix::hash_outputs(int, vector<M >&, M&, bool);
ZZZZ(gf2n_long)
ZZZ(gf2n_short, Matrix<gf2n_short_square>)
ZZZ(gf2n_long, Matrix<Square<gf2n_long>>)
ZZZ(gfp1, Matrix<Square<gfp1>>)
ZZZ(gfp3, Matrix<Square<gfp3>>)
ZZZ(BitVec, Matrix<BitDiagonal>)
#undef XX
#define XX(T,U,N,L) \
template class OTCorrelator<Matrix<Rectangle<Z2<N>, Z2<L> > > >; \
template void OTCorrelator<Matrix<Rectangle<Z2<N>, Z2<L> > > >::reduce_squares(unsigned int nTriples, \
vector<U>& output); \
vector<U>& output, int); \
template void OTExtensionWithMatrix::hash_outputs(int, \
std::vector<Matrix<Rectangle<Z2<N>, Z2<L> > >, std::allocator<Matrix<Rectangle<Z2<N>, Z2<L> > > > >&, \
Matrix<Rectangle<Z2<N>, Z2<L> > >&);
Matrix<Rectangle<Z2<N>, Z2<L> > >&, bool);
#undef X
#define X(N,L) \
template void OTCorrelator<Matrix<Rectangle<Z2<N>, Z2<L> > > >::reduce_squares(unsigned int nTriples, \
vector<Z2kRectangle<N, L> >& output); \
vector<Z2kRectangle<N, L> >& output, int); \
XX(Z2<L>,Z2<N>,N,L)
//X(96, 160)

View File

@@ -45,7 +45,9 @@ public:
U& baseReceiverOutput);
void correlate(int start, int slice, BitVector& newReceiverInput, bool useConstantBase, int repeat = 1);
template <class T>
void reduce_squares(unsigned int nTriples, vector<T>& output);
void reduce_squares(unsigned int nTriples, vector<T>& output,
int start = 0);
void common_seed(PRNG& G);
};
class OTExtensionWithMatrix : public OTCorrelator<BitMatrix>
@@ -80,7 +82,8 @@ public:
void transpose(int start, int slice);
void expand_transposed();
template <class V>
void hash_outputs(int nOTs, vector<V>& senderOutput, V& receiverOutput);
void hash_outputs(int nOTs, vector<V>& senderOutput, V& receiverOutput,
bool correlated = true);
void print(BitVector& newReceiverInput, int i = 0);
template <class T>

View File

@@ -7,18 +7,11 @@
#include "OT/OTMultiplier.h"
#include "OT/NPartyTripleGenerator.h"
#include "OT/Rectangle.h"
#include "Math/Z2k.h"
#include "Math/BitVec.h"
#include "Protocols/SemiShare.h"
#include "Protocols/Semi2kShare.h"
#include "Protocols/Spdz2kShare.h"
#include "OT/BaseOT.h"
#include "OT/OTVole.hpp"
#include "OT/Row.hpp"
#include "OT/Rectangle.hpp"
#include "Math/Z2k.hpp"
#include "Math/Square.hpp"
#include <math.h>
@@ -31,7 +24,8 @@ OTMultiplier<T>::OTMultiplier(OTTripleGenerator<T>& generator,
rot_ext(128, 128, 0, 1,
generator.players[thread_num], generator.baseReceiverInput,
generator.baseSenderInputs[thread_num],
generator.baseReceiverOutputs[thread_num], BOTH, !generator.machine.check),
generator.baseReceiverOutputs[thread_num], BOTH,
!generator.machine.correlation_check),
otCorrelator(0, 0, 0, 0, generator.players[thread_num], {}, {}, {}, BOTH, true)
{
this->thread = 0;
@@ -89,7 +83,7 @@ OTMultiplier<T>::~OTMultiplier()
template<class T>
void OTMultiplier<T>::multiply()
{
keyBits.set(generator.machine.template get_mac_key<typename T::mac_key_type>());
keyBits.set(generator.get_mac_key());
rot_ext.extend(keyBits.size(), keyBits);
this->outbox.push({});
senderOutput.resize(keyBits.size());
@@ -140,11 +134,6 @@ void OTMultiplier<W>::multiplyForTriples()
{
typedef typename W::Rectangle X;
// dummy input for OT correlator
vector<BitVector> _;
vector< vector<BitVector> > __;
BitVector ___;
otCorrelator.resize(X::N_COLUMNS * generator.nPreampTriplesPerLoop);
rot_ext.resize(X::N_ROWS * generator.nPreampTriplesPerLoop + 2 * 128);
@@ -161,8 +150,26 @@ void OTMultiplier<W>::multiplyForTriples()
this->inbox.pop(job);
BitVector aBits = generator.valueBits[0];
//timers["Extension"].start();
rot_ext.extend_correlated(aBits);
rot_ext.hash_outputs(aBits.size(), baseSenderOutputs, baseReceiverOutput);
if (generator.machine.use_extension)
{
rot_ext.extend_correlated(aBits);
}
else
{
BaseOT bot(aBits.size(), -1, generator.players[thread_num]);
bot.set_receiver_inputs(aBits);
bot.exec_base(false);
for (size_t i = 0; i < aBits.size(); i++)
{
rot_ext.receiverOutputMatrix[i] =
bot.receiver_outputs[i].get_int128(0).a;
for (int j = 0; j < 2; j++)
rot_ext.senderOutputMatrices[j][i] =
bot.sender_inputs[i][j].get_int128(0).a;
}
}
rot_ext.hash_outputs(aBits.size(), baseSenderOutputs,
baseReceiverOutput, generator.machine.use_extension);
//timers["Extension"].stop();
//timers["Correlation"].start();
@@ -215,8 +222,6 @@ void MascotMultiplier<U>::after_correlation()
{
typedef typename U::open_type T;
this->auth_ot_ext.resize(
this->generator.nPreampTriplesPerLoop * T::Square::N_COLUMNS);
this->auth_ot_ext.set_role(BOTH);
this->otCorrelator.reduce_squares(this->generator.nPreampTriplesPerLoop,
@@ -229,15 +234,45 @@ void MascotMultiplier<U>::after_correlation()
this->macs.resize(3);
MultJob job;
this->inbox.pop(job);
auto& generator = this->generator;
array<int, 3> n_vals;
for (int j = 0; j < 3; j++)
{
int nValues = this->generator.nTriplesPerLoop;
n_vals[j] = generator.nTriplesPerLoop;
if (this->generator.machine.check && (j % 2 == 0))
nValues *= 2;
this->auth_ot_ext.expand(0, nValues);
this->auth_ot_ext.correlate(0, nValues,
this->generator.valueBits[j], true);
this->auth_ot_ext.reduce_squares(nValues, this->macs[j]);
n_vals[j] *= 2;
}
if (generator.machine.fewer_rounds)
{
BitVector bits;
int total = 0;
for (int j = 0; j < 3; j++)
{
bits.append(generator.valueBits[j],
n_vals[j] * T::Square::N_COLUMNS);
total += n_vals[j];
}
this->auth_ot_ext.resize(bits.size());
this->auth_ot_ext.expand(0, total);
this->auth_ot_ext.correlate(0, total, bits, true);
total = 0;
for (int j = 0; j < 3; j++)
{
this->auth_ot_ext.reduce_squares(n_vals[j], this->macs[j], total);
total += n_vals[j];
}
}
else
{
this->auth_ot_ext.resize(n_vals[0] * T::Square::N_COLUMNS);
for (int j = 0; j < 3; j++)
{
int nValues = n_vals[j];
this->auth_ot_ext.expand(0, nValues);
this->auth_ot_ext.correlate(0, nValues,
this->generator.valueBits[j], true);
this->auth_ot_ext.reduce_squares(nValues, this->macs[j]);
}
}
this->outbox.push(job);
}

View File

@@ -41,3 +41,22 @@ void OTTripleSetup::close_connections()
delete players[i];
}
}
OTTripleSetup OTTripleSetup::get_fresh()
{
OTTripleSetup res = *this;
for (int i = 0; i < nparties - 1; i++)
{
BaseOT bot(nbase, 128, 0);
bot.sender_inputs = baseSenderInputs[i];
bot.receiver_outputs = baseReceiverOutputs[i];
bot.set_seeds();
bot.extend_length();
baseSenderInputs[i] = bot.sender_inputs;
baseReceiverOutputs[i] = bot.receiver_outputs;
bot.extend_length();
res.baseSenderInputs[i] = bot.sender_inputs;
res.baseReceiverOutputs[i] = bot.receiver_outputs;
}
return res;
}

View File

@@ -11,7 +11,7 @@
*/
class OTTripleSetup
{
vector<int> base_receiver_inputs;
BitVector base_receiver_inputs;
vector<BaseOT*> baseOTs;
PRNG G;
@@ -25,10 +25,10 @@ public:
vector< vector< vector<BitVector> > > baseSenderInputs;
vector< vector<BitVector> > baseReceiverOutputs;
int get_nparties() { return nparties; }
int get_nbase() { return nbase; }
int get_my_num() { return my_num; }
int get_base_receiver_input(int i) { return base_receiver_inputs[i]; }
int get_nparties() const { return nparties; }
int get_nbase() const { return nbase; }
int get_my_num() const { return my_num; }
int get_base_receiver_input(int i) const { return base_receiver_inputs[i]; }
OTTripleSetup(Player& N, bool real_OTs)
: nparties(N.num_players()), my_num(N.my_num()), nbase(128)
@@ -78,6 +78,8 @@ public:
//template <class T>
//T get_mac_key();
OTTripleSetup get_fresh();
};

View File

@@ -16,13 +16,16 @@ public:
T b;
T c[N];
int repeat(int l)
int repeat(int l, bool check)
{
switch (l)
{
case 0:
case 2:
return N;
if (check)
return N;
else
return 1;
case 1:
return 1;
default:
@@ -75,12 +78,12 @@ class PlainTriple : public Triple<T,N>
{
public:
// this assumes that valueBits[1] is still set to the bits of b
void to(vector<BitVector>& valueBits, int i)
void to(vector<BitVector>& valueBits, int i, int repeat = N)
{
for (int j = 0; j < N; j++)
{
valueBits[0].set_portion(i * N + j, this->a[j]);
valueBits[2].set_portion(i * N + j, this->c[j]);
valueBits[0].set_portion(i * repeat + j, this->a[j]);
valueBits[2].set_portion(i * repeat + j, this->c[j]);
}
}
};
@@ -123,12 +126,12 @@ public:
{
for (int l = 0; l < 3; l++)
{
int repeat = this->repeat(l);
int repeat = this->repeat(l, generator.machine.check);
for (int j = 0; j < repeat; j++)
{
T value = triple.byIndex(l,j);
T mac;
mac.mul(value, generator.machine.template get_mac_key<U>());
mac.mul(value, generator.get_mac_key());
for (int i = 0; i < generator.nparties-1; i++)
mac += generator.ot_multipliers[i]->macs.at(l).at(iTriple * repeat + j);
Share<T>& share = this->byIndex(l,j);

View File

@@ -16,28 +16,21 @@ class GeneratorThread;
class MascotParams : virtual public OfflineParams
{
protected:
gf2n_short mac_key2s;
gf2n_long mac_key2l;
gfp1 mac_keyp;
Z2<128> mac_keyz;
public:
string prep_data_dir;
bool generateMACs;
bool amplify;
bool check;
bool correlation_check;
bool generateBits;
bool use_extension;
bool fewer_rounds;
bool fiat_shamir;
struct timeval start, stop;
MascotParams();
void set_passive();
template <class T>
T get_mac_key();
template <class T>
void set_mac_key(T key);
};
class TripleMachine : public OfflineMachineBase, public MascotParams
@@ -45,6 +38,10 @@ class TripleMachine : public OfflineMachineBase, public MascotParams
Names N[2];
int nConnections;
gf2n mac_key2;
gfp1 mac_keyp;
Z2<128> mac_keyz;
public:
int nloops;
bool primeField;
@@ -54,7 +51,8 @@ public:
TripleMachine(int argc, const char** argv);
template<class T>
GeneratorThread* new_generator(OTTripleSetup& setup, int i);
GeneratorThread* new_generator(OTTripleSetup& setup, int i,
typename T::mac_key_type mac_key);
void run();

View File

@@ -30,7 +30,7 @@ public:
string progname;
int nthreads;
vector<vector<OTTripleSetup>> ot_setups;
vector<OTTripleSetup> ot_setups;
static BaseMachine& s();

View File

@@ -94,6 +94,7 @@ public:
virtual void purge() {}
virtual size_t data_sent() { return 0; }
virtual NamedCommStats comm_stats() { return {}; }
virtual void get_three_no_count(Dtype dtype, T& a, T& b, T& c) = 0;
virtual void get_two_no_count(Dtype dtype, T& a, T& b) = 0;
@@ -112,6 +113,7 @@ public:
virtual array<T, 3> get_triple(int n_bits);
virtual void buffer_triples() {}
virtual void buffer_inverses() {}
};
template<class T>

View File

@@ -13,6 +13,8 @@
#include "FixInput.h"
#include "FloatInput.h"
#include "IntInput.hpp"
template<class T>
InputBase<T>::InputBase(ArithmeticProcessor* proc) :
P(0), values_input(0)
@@ -295,7 +297,7 @@ void InputBase<T>::input_mixed(SubProcessor<T>& Proc, const vector<int>& args,
cout << "Please input " << U::NAME << "s:" << endl; \
prepare<U>(Proc, player, &args[i + U::N_DEST + 1], size); \
break;
X(IntInput) X(FixInput) X(FloatInput)
X(IntInput<typename T::clear>) X(FixInput) X(FloatInput)
#undef X
default:
throw runtime_error("unknown input type: " + to_string(type));
@@ -317,7 +319,7 @@ void InputBase<T>::input_mixed(SubProcessor<T>& Proc, const vector<int>& args,
n_arg_tuple = U::N_DEST + U::N_PARAM + 2; \
finalize<U>(Proc, args[i + n_arg_tuple - 1], &args[i + 1], size); \
break;
X(IntInput) X(FixInput) X(FloatInput)
X(IntInput<typename T::clear>) X(FixInput) X(FloatInput)
#undef X
default:
throw runtime_error("unknown input type: " + to_string(type));

View File

@@ -61,6 +61,9 @@ enum
USE_PREP = 0x1C,
STARTGRIND = 0x1D,
STOPGRIND = 0x1E,
NPLAYERS = 0xE2,
THRESHOLD = 0xE3,
PLAYERID = 0xE4,
// Addition
ADDC = 0x20,
ADDS = 0x21,

View File

@@ -177,6 +177,9 @@ void BaseInstruction::parse_operands(istream& s, int pos)
case PRINTCHRINT:
case PRINTSTRINT:
case PRINTINT:
case NPLAYERS:
case THRESHOLD:
case PLAYERID:
r[0]=get_int(s);
break;
// instructions with 3 registers + 1 integer operand
@@ -442,6 +445,9 @@ int BaseInstruction::get_reg_type() const
case CONVMODP:
case GCONVGF2N:
case RAND:
case NPLAYERS:
case THRESHOLD:
case PLAYERID:
return INT;
case PREP:
case USE_PREP:
@@ -1046,10 +1052,10 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
Proc.temp.ans2.output(Proc.private_output, false);
break;
case INPUT:
sint::Input::template input<IntInput>(Proc.Procp, start, size);
sint::Input::template input<IntInput<typename sint::clear>>(Proc.Procp, start, size);
return;
case GINPUT:
sgf2n::Input::template input<IntInput>(Proc.Proc2, start, size);
sgf2n::Input::template input<IntInput<typename sgf2n::clear>>(Proc.Proc2, start, size);
return;
case INPUTFIX:
sint::Input::template input<FixInput>(Proc.Procp, start, size);
@@ -1404,6 +1410,15 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
case STOPGRIND:
CALLGRIND_STOP_INSTRUMENTATION;
break;
case NPLAYERS:
Proc.write_Ci(r[0], Proc.P.num_players());
break;
case THRESHOLD:
Proc.write_Ci(r[0], sint::threshold(Proc.P.num_players()));
break;
case PLAYERID:
Proc.write_Ci(r[0], Proc.P.my_num());
break;
// ***
// TODO: read/write shared GF(2^n) data instructions
// ***

View File

@@ -1,14 +0,0 @@
/*
* IntInput.cpp
*
*/
#include "IntInput.h"
const char* IntInput::NAME = "integer";
void IntInput::read(std::istream& in, const int* params)
{
(void) params;
in >> items[0];
}

View File

@@ -8,6 +8,7 @@
#include <iostream>
template<class T>
class IntInput
{
public:
@@ -17,7 +18,7 @@ public:
const static int TYPE = 0;
long items[N_DEST];
T items[N_DEST];
void read(std::istream& in, const int* params);
};

15
Processor/IntInput.hpp Normal file
View File

@@ -0,0 +1,15 @@
/*
* IntInput.cpp
*
*/
#include "IntInput.h"
template<class T>
const char* IntInput<T>::NAME = "integer";
template<class T>
void IntInput<T>::read(std::istream& in, const int*)
{
in >> items[0];
}

View File

@@ -127,10 +127,8 @@ Machine<sint, sgf2n>::Machine(int my_number, Names& playerNames,
P = new CryptoPlayer(playerNames, 0xF000);
else
P = new PlainPlayer(playerNames, 0xF000);
ot_setups.resize(nthreads);
for (int i = 0; i < nthreads; i++)
for (int j = 0; j < 3; j++)
ot_setups.at(i).push_back({ *P, true });
ot_setups.push_back({ *P, true });
delete P;
}

View File

@@ -52,8 +52,6 @@ public:
Preprocessing<T>& DataF, Player& P);
// Access to PO (via calls to POpen start/stop)
void POpen_Start(const vector<int>& reg,const Player& P,int size);
void POpen_Stop(const vector<int>& reg,const Player& P,int size);
void POpen(const vector<int>& reg,const Player& P,int size);
void muls(const vector<int>& reg, int size);

View File

@@ -6,6 +6,7 @@
#include "Protocols/ReplicatedInput.hpp"
#include "Protocols/ReplicatedPrivateOutput.hpp"
#include "Processor/ProcessorBase.hpp"
#include <sodium.h>
#include <string>
@@ -406,15 +407,16 @@ void Processor<sint, sgf2n>::write_shares_to_file(const vector<int>& data_regist
}
template <class T>
void SubProcessor<T>::POpen_Start(const vector<int>& reg,const Player& P,int size)
void SubProcessor<T>::POpen(const vector<int>& reg,const Player& P,int size)
{
int sz=reg.size();
assert(reg.size() % 2 == 0);
int sz=reg.size() / 2;
Sh_PO.clear();
Sh_PO.reserve(sz*size);
if (size>1)
{
for (typename vector<int>::const_iterator reg_it=reg.begin();
reg_it!=reg.end(); reg_it++)
for (typename vector<int>::const_iterator reg_it=reg.begin() + 1;
reg_it < reg.end(); reg_it += 2)
{
auto begin=S.begin()+*reg_it;
Sh_PO.insert(Sh_PO.end(),begin,begin+size);
@@ -423,24 +425,15 @@ void SubProcessor<T>::POpen_Start(const vector<int>& reg,const Player& P,int siz
else
{
for (int i=0; i<sz; i++)
{ Sh_PO.push_back(S[reg[i]]); }
{ Sh_PO.push_back(S[reg[2 * i + 1]]); }
}
PO.resize(sz*size);
MC.POpen_Begin(PO,Sh_PO,P);
}
template <class T>
void SubProcessor<T>::POpen_Stop(const vector<int>& reg,const Player& P,int size)
{
int sz=reg.size();
PO.resize(sz*size);
MC.POpen_End(PO,Sh_PO,P);
MC.POpen(PO,Sh_PO,P);
if (size>1)
{
auto PO_it=PO.begin();
for (typename vector<int>::const_iterator reg_it=reg.begin();
reg_it!=reg.end(); reg_it++)
reg_it!=reg.end(); reg_it += 2)
{
for (auto C_it=C.begin()+*reg_it;
C_it!=C.begin()+*reg_it+size; C_it++)
@@ -452,36 +445,16 @@ void SubProcessor<T>::POpen_Stop(const vector<int>& reg,const Player& P,int size
}
else
{
for (unsigned int i=0; i<reg.size(); i++)
{ C[reg[i]] = PO[i]; }
for (unsigned int i = 0; i < reg.size() / 2; i++)
{
C[reg[2 * i]] = PO[i];
}
}
Proc.sent += reg.size() * size;
Proc.rounds++;
}
inline void unzip_open(vector<int>& dest, vector<int>& source, const vector<int>& reg)
{
int n = reg.size() / 2;
source.resize(n);
dest.resize(n);
for (int i = 0; i < n; i++)
{
source[i] = reg[2 * i + 1];
dest[i] = reg[2 * i];
}
}
template<class T>
void SubProcessor<T>::POpen(const vector<int>& reg, const Player& P,
int size)
{
vector<int> source, dest;
unzip_open(dest, source, reg);
POpen_Start(source, P, size);
POpen_Stop(dest, P, size);
}
template<class T>
void SubProcessor<T>::muls(const vector<int>& reg, int size)
{

View File

@@ -3,6 +3,9 @@
*
*/
#ifndef PROCESSOR_PROCESSORBASE_HPP_
#define PROCESSOR_PROCESSORBASE_HPP_
#include "ProcessorBase.h"
#include "IntInput.h"
#include "FixInput.h"
@@ -11,6 +14,7 @@
#include <iostream>
inline
void ProcessorBase::open_input_file(const string& name)
{
#ifdef DEBUG_FILES
@@ -20,6 +24,7 @@ void ProcessorBase::open_input_file(const string& name)
input_filename = name;
}
inline
void ProcessorBase::open_input_file(int my_num, int thread_num)
{
string input_file = "Player-Data/Input-P" + to_string(my_num) + "-" + to_string(thread_num);
@@ -54,6 +59,4 @@ T ProcessorBase::get_input(istream& input_file, const string& input_filename, co
return res;
}
template IntInput ProcessorBase::get_input(bool, const int*);
template FixInput ProcessorBase::get_input(bool, const int*);
template FloatInput ProcessorBase::get_input(bool, const int*);
#endif

View File

@@ -129,7 +129,7 @@ def expandAESKey(cipherKey, Nr = 10, Nb = 4, Nk = 4):
temp[2] = box.apply_sbox(temp[2])
temp[3] = box.apply_sbox(temp[3])
temp[0] = temp[0] + ApplyEmbedding(rcon[int(i/Nk)])
temp[0] = temp[0] + ApplyEmbedding(rcon[int(i//Nk)])
for j in range(4):
round_key[4 * i + j] = round_key[4 * (i - Nk) + j] + temp[j]
@@ -233,7 +233,7 @@ def inverseMod(val):
for idx in range(40):
if idx % 5 == 0:
bd_val[idx] = raw_bit_dec[idx / 5]
bd_val[idx] = raw_bit_dec[idx // 5]
bd_squared = bd_val
squared_index = 2

View File

@@ -118,10 +118,10 @@ def main():
return True
if n_rounds > 0:
print 'run %d rounds' % n_rounds
print('run %d rounds' % n_rounds)
for_range(n_rounds)(game_loop)
else:
print 'run forever'
print('run forever')
do_while(game_loop)
main()

View File

@@ -13,7 +13,7 @@ if len(program.args) > 2:
m = int(program.args[2])
pack = min(n, 50)
n = (n + pack - 1) / pack
n = (n + pack - 1) // pack
a = sbit(1)
b = sbit(1, n=pack)

View File

@@ -50,9 +50,9 @@ test_decryption = True
instructions_base.set_global_vector_size(n_parallel)
if use_mimc_prf:
execfile('./Programs/Source/prf_mimc.mpc')
exec(compile(__builtins__['open']('./Programs/Source/prf_mimc.mpc').read(), './Programs/Source/prf_mimc.mpc', 'exec'))
elif use_leg_prf:
execfile('./Programs/Source/prf_leg.mpc')
exec(compile(__builtins__['open']('./Programs/Source/prf_leg.mpc').read(), './Programs/Source/prf_leg.mpc', 'exec'))
class HMAC(object):
def __init__(self, _enc):
@@ -97,7 +97,7 @@ class NonceEncryptMAC(object):
def get_long_random(self, nbits):
""" Returns random cint() % 2^{nbits} """
result = cint(0)
for i in range(nbits / 30):
for i in range(nbits // 30):
result += cint(regint.get_random(30))
result <<= 30
@@ -178,7 +178,7 @@ def time_private_mac(n_total, n_parallel, nmessages):
# Benchmark n_total HtMAC's while executing in parallel n_parallel
start_timer(1)
@for_range(n_total / n_parallel)
@for_range(n_total // n_parallel)
def block(index):
# Re-use off-line data after n_parallel runs for benchmarking purposes.
# If real system-use need to initialize num_calls with a larger constant.

View File

@@ -8,7 +8,7 @@ ml.set_n_threads(8)
debug = False
if 'halfprec' in program.args:
print '8-bit precision after point'
print('8-bit precision after point')
sfix.set_precision(8, 31)
cfix.set_precision(8, 31)
else:
@@ -26,7 +26,7 @@ n_features = 12634
if len(program.args) > 2:
if 'bc' in program.args:
print 'Compiling for BC-TCGA'
print('Compiling for BC-TCGA')
n_examples = 472
n_normal = 49
n_features = 17814
@@ -41,7 +41,7 @@ try:
except:
pass
print 'Using %d threads' % ml.Layer.n_threads
print('Using %d threads' % ml.Layer.n_threads)
n_fold = 5
test_share = 1. / n_fold
@@ -63,8 +63,8 @@ else:
n_test = sum(n_tests)
indices = [regint.Array(n) for n in n_ex]
indices[0].assign(range(n_pos, n_pos + n_normal))
indices[1].assign(range(n_pos))
indices[0].assign(list(range(n_pos, n_pos + n_normal)))
indices[1].assign(list(range(n_pos)))
test = regint.Array(n_test)
@@ -97,7 +97,7 @@ for arg in program.args:
m = re.match('tol=(.*)', arg)
if m:
sgd.tol = float(m.group(1))
print 'Stop with tolerance', sgd.tol
print('Stop with tolerance', sgd.tol)
sum_acc = cfix.MemValue(0)

View File

@@ -6,8 +6,8 @@ sbitfix.set_precision(16, 32)
def test(a, b, value_type=None):
try:
b = int(round((b * (1 << a.f))))
if b < 0:
b += 2 ** sbitfix.k
if b < 0:
b += 2 ** sbitfix.k
a = a.v.reveal()
except AttributeError:
pass

View File

@@ -55,12 +55,12 @@ def f(_):
def thread():
i = get_arg()
n_per_thread = n_inputs / n_threads
n_per_thread = n_inputs // n_threads
if n_per_thread % 2 != 0:
raise Exception('Number of inputs must be divisible by 2')
start = i * n_per_thread
tuples = [bid_sort(bids[start+2*j], bids[start+2*j+1]) \
for j in range(n_per_thread / 2)]
for j in range(n_per_thread // 2)]
first, second = util.tree_reduce(first_and_second, tuples)
results[2*i], results[2*i+1] = first, second

View File

@@ -29,6 +29,8 @@ class Beaver : public ProtocolBase<T>
typename T::MAC_Check* MC;
public:
static const bool uses_triples = true;
Player& P;
Beaver(Player& P) : prep(0), MC(0), P(P) {}
@@ -39,6 +41,9 @@ public:
void exchange();
T finalize_mul(int n = -1);
void start_exchange();
void stop_exchange();
int get_n_relevant_players() { return P.num_players(); }
};

View File

@@ -50,6 +50,20 @@ void Beaver<T>::exchange()
triple = triples.begin();
}
template<class T>
void Beaver<T>::start_exchange()
{
MC->POpen_Begin(opened, shares, P);
}
template<class T>
void Beaver<T>::stop_exchange()
{
MC->POpen_End(opened, shares, P);
it = opened.begin();
triple = triples.begin();
}
template<class T>
T Beaver<T>::finalize_mul(int n)
{

39
Protocols/HemiPrep.h Normal file
View File

@@ -0,0 +1,39 @@
/*
* HemiPrep.h
*
*/
#ifndef PROTOCOLS_HEMIPREP_H_
#define PROTOCOLS_HEMIPREP_H_
#include "ReplicatedPrep.h"
#include "FHEOffline/Multiplier.h"
template<class T>
class HemiPrep : public SemiHonestRingPrep<T>
{
typedef typename T::clear::FD FD;
static PairwiseMachine* pairwise_machine;
static Lock lock;
vector<Multiplier<FD>*> multipliers;
SeededPRNG G;
map<string, Timer> timers;
public:
static void basic_setup(Player& P);
static void teardown();
HemiPrep(SubProcessor<T>* proc, DataPositions& usage) :
RingPrep<T>(proc, usage), SemiHonestRingPrep<T>(proc, usage)
{
}
void buffer_triples();
void buffer_inverses();
};
#endif /* PROTOCOLS_HEMIPREP_H_ */

Some files were not shown because too many files have changed in this diff Show More