mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-01-08 05:03:59 -05:00
Python 3, semi-honest computation using semi-homomorphic encryption.
This commit is contained in:
@@ -133,7 +133,7 @@ void RealGarbleWire<T>::input(party_id_t from, char input)
|
||||
protocol.init_mul(party.shared_proc);
|
||||
protocol.prepare_mul(mask, T(1, party.P->my_num(), party.mac_key) - mask);
|
||||
protocol.exchange();
|
||||
if (party.MC->POpen(protocol.finalize_mul(), *party.P) != 0)
|
||||
if (party.MC->open(protocol.finalize_mul(), *party.P) != 0)
|
||||
throw runtime_error("input mask not a bit");
|
||||
}
|
||||
#ifdef DEBUG_MASK
|
||||
@@ -168,7 +168,7 @@ void RealGarbleWire<T>::output()
|
||||
auto& party = RealProgramParty<T>::s();
|
||||
assert(party.MC != 0);
|
||||
assert(party.P != 0);
|
||||
auto m = party.MC->POpen(mask, *party.P);
|
||||
auto m = party.MC->open(mask, *party.P);
|
||||
party.output_masks.push_back(m.get_bit(0));
|
||||
party.taint();
|
||||
#ifdef DEBUG_MASK
|
||||
|
||||
@@ -90,7 +90,7 @@ RealProgramParty<T>::RealProgramParty(int argc, const char** argv) :
|
||||
{
|
||||
mac_key.randomize(prng);
|
||||
if (T::needs_ot)
|
||||
BaseMachine::s().ot_setups.push_back({{{*P, true}}});
|
||||
BaseMachine::s().ot_setups.push_back({*P, true});
|
||||
prep = Preprocessing<T>::get_live_prep(0, usage);
|
||||
}
|
||||
else
|
||||
|
||||
@@ -1,5 +1,11 @@
|
||||
The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here.
|
||||
|
||||
## 0.1.3 (Nov 21, 2019)
|
||||
|
||||
- Python 3
|
||||
- Semi-honest computation based on semi-homomorphic encryption
|
||||
- Access to player information in high-level language
|
||||
|
||||
## 0.1.2 (Oct 11, 2019)
|
||||
|
||||
- Machine learning capabilities used for [MobileNets inference](https://eprint.iacr.org/2019/131) and the iDASH submission
|
||||
|
||||
@@ -5,6 +5,6 @@ class Program(object):
|
||||
types.program = self
|
||||
instructions.program = self
|
||||
self.curr_tape = None
|
||||
execfile(progname)
|
||||
exec(compile(open(progname).read(), progname, 'exec'))
|
||||
def malloc(self, *args):
|
||||
pass
|
||||
|
||||
@@ -5,6 +5,7 @@ from Compiler.exceptions import *
|
||||
from Compiler import util, oram, floatingpoint, library
|
||||
import Compiler.GC.instructions as inst
|
||||
import operator
|
||||
from functools import reduce
|
||||
|
||||
class bits(Tape.Register, _structure):
|
||||
n = 40
|
||||
@@ -82,7 +83,7 @@ class bits(Tape.Register, _structure):
|
||||
cls.load_inst[util.is_constant(address)](res, address)
|
||||
return res
|
||||
def store_in_mem(self, address):
|
||||
self.store_inst[isinstance(address, (int, long))](self, address)
|
||||
self.store_inst[isinstance(address, int)](self, address)
|
||||
def __init__(self, value=None, n=None, size=None):
|
||||
if size != 1 and size is not None:
|
||||
raise Exception('invalid size for bit type: %s' % size)
|
||||
@@ -92,11 +93,11 @@ class bits(Tape.Register, _structure):
|
||||
self.load_other(value)
|
||||
def set_length(self, n):
|
||||
if n > self.max_length:
|
||||
print self.max_length
|
||||
print(self.max_length)
|
||||
raise Exception('too long: %d' % n)
|
||||
self.n = n
|
||||
def load_other(self, other):
|
||||
if isinstance(other, (int, long)):
|
||||
if isinstance(other, int):
|
||||
self.set_length(self.n or util.int_len(other))
|
||||
self.load_int(other)
|
||||
elif isinstance(other, regint):
|
||||
@@ -115,6 +116,7 @@ class bits(Tape.Register, _structure):
|
||||
def __repr__(self):
|
||||
return '%s(%d/%d)' % \
|
||||
(super(bits, self).__repr__(), self.n, type(self).n)
|
||||
__str__ = __repr__
|
||||
|
||||
class cbits(bits):
|
||||
max_length = 64
|
||||
@@ -219,13 +221,13 @@ class sbits(bits):
|
||||
@classmethod
|
||||
def load_dynamic_mem(cls, address):
|
||||
res = cls()
|
||||
if isinstance(address, (long, int)):
|
||||
if isinstance(address, int):
|
||||
inst.ldmsd(res, address, cls.n)
|
||||
else:
|
||||
inst.ldmsdi(res, address, cls.n)
|
||||
return res
|
||||
def store_in_dynamic_mem(self, address):
|
||||
if isinstance(address, (long, int)):
|
||||
if isinstance(address, int):
|
||||
inst.stmsd(self, address)
|
||||
else:
|
||||
inst.stmsdi(self, cbits.conv(address))
|
||||
@@ -322,7 +324,7 @@ class sbits(bits):
|
||||
mul_bits = [self if b else zero for b in bits]
|
||||
return self.bit_compose(mul_bits)
|
||||
else:
|
||||
print self.n, other
|
||||
print(self.n, other)
|
||||
return NotImplemented
|
||||
def __lshift__(self, i):
|
||||
return self.bit_compose([sbit(0)] * i + self.bit_decompose()[:self.max_length-i])
|
||||
@@ -478,7 +480,7 @@ class bitsBlock(oram.Block):
|
||||
self.start_demux = oram.demux_list(self.start_bits)
|
||||
self.entries = [sbits.bit_compose(self.value_bits[i*length:][:length]) \
|
||||
for i in range(entries_per_block)]
|
||||
self.mul_entries = map(operator.mul, self.start_demux, self.entries)
|
||||
self.mul_entries = list(map(operator.mul, self.start_demux, self.entries))
|
||||
self.bits = sum(self.mul_entries).bit_decompose()
|
||||
self.mul_value = sbits.compose(self.mul_entries, sum(self.lengths))
|
||||
self.anti_value = self.mul_value + self.value
|
||||
@@ -662,6 +664,12 @@ class sbitfix(_fix):
|
||||
return super(sbitfix, self).__mul__(other)
|
||||
__rxor__ = __xor__
|
||||
__rmul__ = __mul__
|
||||
@staticmethod
|
||||
def multipliable(other, k, f):
|
||||
class cls(_fix):
|
||||
int_type = sbitint.get_type(k)
|
||||
cls.set_precision(f, k)
|
||||
return cls._new(cls.int_type(other), k, f)
|
||||
|
||||
sbitfix.set_precision(20, 41)
|
||||
|
||||
|
||||
@@ -1,15 +1,15 @@
|
||||
import compilerLib, program, instructions, types, library, floatingpoint
|
||||
import GC.types
|
||||
from . import compilerLib, program, instructions, types, library, floatingpoint
|
||||
from .GC import types as GC_types
|
||||
import inspect
|
||||
from config import *
|
||||
from compilerLib import run
|
||||
from .config import *
|
||||
from .compilerLib import run
|
||||
|
||||
|
||||
# add all instructions to the program VARS dictionary
|
||||
compilerLib.VARS = {}
|
||||
instr_classes = [t[1] for t in inspect.getmembers(instructions, inspect.isclass)]
|
||||
|
||||
for mod in (types, GC.types):
|
||||
for mod in (types, GC_types):
|
||||
instr_classes += [t[1] for t in inspect.getmembers(mod, inspect.isclass)\
|
||||
if t[1].__module__ == mod.__name__]
|
||||
|
||||
|
||||
@@ -10,16 +10,17 @@ import Compiler.program
|
||||
import heapq, itertools
|
||||
import operator
|
||||
import sys
|
||||
from functools import reduce
|
||||
|
||||
|
||||
class StraightlineAllocator:
|
||||
"""Allocate variables in a straightline program using n registers.
|
||||
It is based on the precondition that every register is only defined once."""
|
||||
def __init__(self, n):
|
||||
self.alloc = {}
|
||||
self.alloc = dict_by_id()
|
||||
self.usage = Compiler.program.RegType.create_dict(lambda: 0)
|
||||
self.defined = {}
|
||||
self.dealloc = set()
|
||||
self.defined = dict_by_id()
|
||||
self.dealloc = set_by_id()
|
||||
self.n = n
|
||||
|
||||
def alloc_reg(self, reg, free):
|
||||
@@ -77,8 +78,8 @@ class StraightlineAllocator:
|
||||
unused_regs.append(j)
|
||||
if unused_regs and len(unused_regs) == len(list(i.get_def())):
|
||||
# only report if all assigned registers are unused
|
||||
print "Register(s) %s never used, assigned by '%s' in %s" % \
|
||||
(unused_regs,i,format_trace(i.caller))
|
||||
print("Register(s) %s never used, assigned by '%s' in %s" % \
|
||||
(unused_regs,i,format_trace(i.caller)))
|
||||
|
||||
for j in i.get_used():
|
||||
self.alloc_reg(j, alloc_pool)
|
||||
@@ -86,7 +87,7 @@ class StraightlineAllocator:
|
||||
self.dealloc_reg(j, i, alloc_pool)
|
||||
|
||||
if k % 1000000 == 0 and k > 0:
|
||||
print "Allocated registers for %d instructions at" % k, time.asctime()
|
||||
print("Allocated registers for %d instructions at" % k, time.asctime())
|
||||
|
||||
# print "Successfully allocated registers"
|
||||
# print "modp usage: %d clear, %d secret" % \
|
||||
@@ -97,8 +98,8 @@ class StraightlineAllocator:
|
||||
|
||||
|
||||
def determine_scope(block, options):
|
||||
last_def = defaultdict(lambda: -1)
|
||||
used_from_scope = set()
|
||||
last_def = defaultdict_by_id(lambda: -1)
|
||||
used_from_scope = set_by_id()
|
||||
|
||||
def find_in_scope(reg, scope):
|
||||
while True:
|
||||
@@ -114,18 +115,18 @@ def determine_scope(block, options):
|
||||
used_from_scope.add(reg)
|
||||
reg.can_eliminate = False
|
||||
else:
|
||||
print 'Warning: read before write at register', reg
|
||||
print '\tline %d: %s' % (n, instr)
|
||||
print '\tinstruction trace: %s' % format_trace(instr.caller, '\t\t')
|
||||
print '\tregister trace: %s' % format_trace(reg.caller, '\t\t')
|
||||
print('Warning: read before write at register', reg)
|
||||
print('\tline %d: %s' % (n, instr))
|
||||
print('\tinstruction trace: %s' % format_trace(instr.caller, '\t\t'))
|
||||
print('\tregister trace: %s' % format_trace(reg.caller, '\t\t'))
|
||||
if options.stop:
|
||||
sys.exit(1)
|
||||
|
||||
def write(reg, n):
|
||||
if last_def[reg] != -1:
|
||||
print 'Warning: double write at register', reg
|
||||
print '\tline %d: %s' % (n, instr)
|
||||
print '\ttrace: %s' % format_trace(instr.caller, '\t\t')
|
||||
print('Warning: double write at register', reg)
|
||||
print('\tline %d: %s' % (n, instr))
|
||||
print('\ttrace: %s' % format_trace(instr.caller, '\t\t'))
|
||||
if options.stop:
|
||||
sys.exit(1)
|
||||
last_def[reg] = n
|
||||
@@ -146,7 +147,7 @@ def determine_scope(block, options):
|
||||
write(reg, n)
|
||||
|
||||
block.used_from_scope = used_from_scope
|
||||
block.defined_registers = set(last_def.iterkeys())
|
||||
block.defined_registers = set_by_id(last_def.keys())
|
||||
|
||||
class Merger:
|
||||
def __init__(self, block, options, merge_classes):
|
||||
@@ -178,7 +179,7 @@ class Merger:
|
||||
if inst.is_vec():
|
||||
for arg in inst.args:
|
||||
arg.create_vector_elements()
|
||||
res = sum(zip(*inst.args), ())
|
||||
res = sum(list(zip(*inst.args)), ())
|
||||
return list(res)
|
||||
else:
|
||||
return inst.args
|
||||
@@ -241,7 +242,7 @@ class Merger:
|
||||
remaining_input_nodes = []
|
||||
def do_merge(nodes):
|
||||
if len(nodes) > 1000:
|
||||
print 'Merging %d inputs...' % len(nodes)
|
||||
print('Merging %d inputs...' % len(nodes))
|
||||
self.do_merge(iter(nodes))
|
||||
for n in self.input_nodes:
|
||||
inst = self.instructions[n]
|
||||
@@ -252,7 +253,7 @@ class Merger:
|
||||
if len(merge) >= self.max_parallel_open:
|
||||
do_merge(merge)
|
||||
merge[:] = []
|
||||
for merge in reversed(sorted(merges.itervalues())):
|
||||
for merge in reversed(sorted(merges.values())):
|
||||
if merge:
|
||||
do_merge(merge)
|
||||
self.input_nodes = remaining_input_nodes
|
||||
@@ -266,7 +267,7 @@ class Merger:
|
||||
instructions = self.instructions
|
||||
flex_nodes = defaultdict(dict)
|
||||
starters = []
|
||||
for n in xrange(len(G)):
|
||||
for n in range(len(G)):
|
||||
if n not in merge_nodes_set and \
|
||||
depth_of[n] != rev_depth_of[n] and G[n] and G.get_attr(n,'start') == -1 and not isinstance(instructions[n], AsymmetricCommunicationInstruction):
|
||||
#print n, depth_of[n], rev_depth_of[n]
|
||||
@@ -275,19 +276,19 @@ class Merger:
|
||||
not isinstance(self.instructions[n], RawInputInstruction):
|
||||
starters.append(n)
|
||||
if n % 10000000 == 0 and n > 0:
|
||||
print "Processed %d nodes at" % n, time.asctime()
|
||||
print("Processed %d nodes at" % n, time.asctime())
|
||||
|
||||
inputs = defaultdict(list)
|
||||
for node in self.input_nodes:
|
||||
player = self.instructions[node].args[0]
|
||||
inputs[player].append(node)
|
||||
first_inputs = [l[0] for l in inputs.itervalues()]
|
||||
first_inputs = [l[0] for l in inputs.values()]
|
||||
other_inputs = []
|
||||
i = 0
|
||||
while True:
|
||||
i += 1
|
||||
found = False
|
||||
for l in inputs.itervalues():
|
||||
for l in inputs.values():
|
||||
if i < len(l):
|
||||
other_inputs.append(l[i])
|
||||
found = True
|
||||
@@ -299,20 +300,20 @@ class Merger:
|
||||
# magical preorder for topological search
|
||||
max_depth = max(merges)
|
||||
if max_depth > 10000:
|
||||
print "Computing pre-ordering ..."
|
||||
for i in xrange(max_depth, 0, -1):
|
||||
print("Computing pre-ordering ...")
|
||||
for i in range(max_depth, 0, -1):
|
||||
preorder.append(G.get_attr(merges[i], 'stop'))
|
||||
for j in flex_nodes[i-1].itervalues():
|
||||
for j in flex_nodes[i-1].values():
|
||||
preorder.extend(j)
|
||||
preorder.extend(flex_nodes[0].get(i, []))
|
||||
preorder.append(merges[i])
|
||||
if i % 100000 == 0 and i > 0:
|
||||
print "Done level %d at" % i, time.asctime()
|
||||
print("Done level %d at" % i, time.asctime())
|
||||
preorder.extend(other_inputs)
|
||||
preorder.extend(starters)
|
||||
preorder.extend(first_inputs)
|
||||
if max_depth > 10000:
|
||||
print "Done at", time.asctime()
|
||||
print("Done at", time.asctime())
|
||||
return preorder
|
||||
|
||||
def longest_paths_merge(self):
|
||||
@@ -343,8 +344,8 @@ class Merger:
|
||||
t = type(self.instructions[merge[0]])
|
||||
self.counter[t] += len(merge)
|
||||
if len(merge) > 1000:
|
||||
print 'Merging %d %s in round %d/%d' % \
|
||||
(len(merge), t.__name__, i, len(merges))
|
||||
print('Merging %d %s in round %d/%d' % \
|
||||
(len(merge), t.__name__, i, len(merges)))
|
||||
self.do_merge(merge)
|
||||
|
||||
self.merge_inputs()
|
||||
@@ -352,11 +353,11 @@ class Merger:
|
||||
preorder = None
|
||||
|
||||
if len(instructions) > 100000:
|
||||
print "Topological sort ..."
|
||||
print("Topological sort ...")
|
||||
order = Compiler.graph.topological_sort(G, preorder)
|
||||
instructions[:] = [instructions[i] for i in order if instructions[i] is not None]
|
||||
if len(instructions) > 100000:
|
||||
print "Done at", time.asctime()
|
||||
print("Done at", time.asctime())
|
||||
|
||||
return len(merges)
|
||||
|
||||
@@ -377,7 +378,7 @@ class Merger:
|
||||
self.G = G
|
||||
|
||||
reg_nodes = {}
|
||||
last_def = defaultdict(lambda: -1)
|
||||
last_def = defaultdict_by_id(lambda: -1)
|
||||
last_mem_write = []
|
||||
last_mem_read = []
|
||||
warned_about_mem = []
|
||||
@@ -411,8 +412,8 @@ class Merger:
|
||||
|
||||
def handle_mem_access(addr, reg_type, last_access_this_kind,
|
||||
last_access_other_kind):
|
||||
this = last_access_this_kind[addr,reg_type]
|
||||
other = last_access_other_kind[addr,reg_type]
|
||||
this = last_access_this_kind[str(addr),reg_type]
|
||||
other = last_access_other_kind[str(addr),reg_type]
|
||||
if this and other:
|
||||
if this[-1] < other[0]:
|
||||
del this[:]
|
||||
@@ -429,15 +430,15 @@ class Merger:
|
||||
handle_mem_access(addr_i, reg_type, last_access_this_kind,
|
||||
last_access_other_kind)
|
||||
if not warned_about_mem and (instr.get_size() > 100):
|
||||
print 'WARNING: Order of memory instructions ' \
|
||||
'not preserved due to long vector, errors possible'
|
||||
print('WARNING: Order of memory instructions ' \
|
||||
'not preserved due to long vector, errors possible')
|
||||
warned_about_mem.append(True)
|
||||
else:
|
||||
handle_mem_access(addr, reg_type, last_access_this_kind,
|
||||
last_access_other_kind)
|
||||
if not warned_about_mem and not isinstance(instr, DirectMemoryInstruction):
|
||||
print 'WARNING: Order of memory instructions ' \
|
||||
'not preserved, errors possible'
|
||||
print('WARNING: Order of memory instructions ' \
|
||||
'not preserved, errors possible')
|
||||
# hack
|
||||
warned_about_mem.append(True)
|
||||
|
||||
@@ -553,11 +554,11 @@ class Merger:
|
||||
self.sources.append(n)
|
||||
|
||||
if n % 100000 == 0 and n > 0:
|
||||
print "Processed dependency of %d/%d instructions at" % \
|
||||
(n, len(block.instructions)), time.asctime()
|
||||
print("Processed dependency of %d/%d instructions at" % \
|
||||
(n, len(block.instructions)), time.asctime())
|
||||
|
||||
if len(open_nodes) > 1000:
|
||||
print "Program has %d %s instructions" % (len(open_nodes), merge_classes)
|
||||
print("Program has %d %s instructions" % (len(open_nodes), merge_classes))
|
||||
|
||||
def merge_nodes(self, i, j):
|
||||
""" Merge node j into i, removing node j """
|
||||
@@ -566,8 +567,8 @@ class Merger:
|
||||
G.remove_edge(i, j)
|
||||
if i in G[j]:
|
||||
G.remove_edge(j, i)
|
||||
G.add_edges_from(zip(itertools.cycle([i]), G[j], [G.weights[(j,k)] for k in G[j]]))
|
||||
G.add_edges_from(zip(G.pred[j], itertools.cycle([i]), [G.weights[(k,j)] for k in G.pred[j]]))
|
||||
G.add_edges_from(list(zip(itertools.cycle([i]), G[j], [G.weights[(j,k)] for k in G[j]])))
|
||||
G.add_edges_from(list(zip(G.pred[j], itertools.cycle([i]), [G.weights[(k,j)] for k in G.pred[j]])))
|
||||
G.get_attr(i, 'merges').append(j)
|
||||
G.remove_node(j)
|
||||
|
||||
@@ -578,7 +579,7 @@ class Merger:
|
||||
count = 0
|
||||
open_count = 0
|
||||
stats = defaultdict(lambda: 0)
|
||||
for i,inst in zip(xrange(len(instructions) - 1, -1, -1), reversed(instructions)):
|
||||
for i,inst in zip(range(len(instructions) - 1, -1, -1), reversed(instructions)):
|
||||
# remove if instruction has result that isn't used
|
||||
unused_result = not G.degree(i) and len(list(inst.get_def())) \
|
||||
and reduce(operator.and_, (reg.can_eliminate for reg in inst.get_def())) \
|
||||
@@ -608,21 +609,21 @@ class Merger:
|
||||
eliminate(i)
|
||||
count += 2
|
||||
if count > 0:
|
||||
print 'Eliminated %d dead instructions, among which %d opens: %s' \
|
||||
% (count, open_count, dict(stats))
|
||||
print('Eliminated %d dead instructions, among which %d opens: %s' \
|
||||
% (count, open_count, dict(stats)))
|
||||
|
||||
def print_graph(self, filename):
|
||||
f = open(filename, 'w')
|
||||
print >>f, 'digraph G {'
|
||||
print('digraph G {', file=f)
|
||||
for i in range(self.G.n):
|
||||
for j in self.G[i]:
|
||||
print >>f, '"%d: %s" -> "%d: %s";' % \
|
||||
(i, self.instructions[i], j, self.instructions[j])
|
||||
print >>f, '}'
|
||||
print('"%d: %s" -> "%d: %s";' % \
|
||||
(i, self.instructions[i], j, self.instructions[j]), file=f)
|
||||
print('}', file=f)
|
||||
f.close()
|
||||
|
||||
def print_depth(self, filename):
|
||||
f = open(filename, 'w')
|
||||
for i in range(self.G.n):
|
||||
print >>f, '%d: %s' % (self.depths[i], self.instructions[i])
|
||||
print('%d: %s' % (self.depths[i], self.instructions[i]), file=f)
|
||||
f.close()
|
||||
|
||||
@@ -26,7 +26,7 @@ def find_deeper(a, b, path, start, length, compute_level=True):
|
||||
any_empty = OR(a.empty, b.empty)
|
||||
a_diff = [XOR(a_bits[i], path_bits[i]) for i in range(start, length)]
|
||||
b_diff = [XOR(b_bits[i], path_bits[i]) for i in range(start, length)]
|
||||
diff = [XOR(ab, bb) for ab,bb in zip(a_bits, b_bits)[start:length]]
|
||||
diff = [XOR(ab, bb) for ab,bb in list(zip(a_bits, b_bits))[start:length]]
|
||||
diff_preor = type(a.value).PreOR([any_empty] + diff)
|
||||
diff_first = [x - y for x,y in zip(diff_preor, diff_preor[1:])]
|
||||
winner = sum((ad * df for ad,df in zip(a_diff, diff_first)), a.empty)
|
||||
@@ -38,7 +38,7 @@ def find_deeper(a, b, path, start, length, compute_level=True):
|
||||
def find_deepest(paths, search_path, start, length, compute_level=True):
|
||||
if len(paths) == 1:
|
||||
return None, paths[0], 1
|
||||
l = len(paths) / 2
|
||||
l = len(paths) // 2
|
||||
_, a, a_index = find_deepest(paths[:l], search_path, start, length, False)
|
||||
_, b, b_index = find_deepest(paths[l:], search_path, start, length, False)
|
||||
level, winner = find_deeper(a, b, search_path, start, length, compute_level)
|
||||
@@ -57,7 +57,7 @@ def greater_unary(a, b):
|
||||
if len(a) == 1:
|
||||
return a[0], b[0]
|
||||
else:
|
||||
l = len(a) / 2
|
||||
l = len(a) // 2
|
||||
return gu_step(greater_unary(a[l:], b[l:]), greater_unary(a[:l], b[:l]))
|
||||
|
||||
def comp_step(high, low):
|
||||
@@ -75,7 +75,7 @@ def comp_binary(a, b):
|
||||
if len(a) == 1:
|
||||
return a[0], b[0]
|
||||
else:
|
||||
l = len(a) / 2
|
||||
l = len(a) // 2
|
||||
return comp_step(comp_binary(a[l:], b[l:]), comp_binary(a[:l], b[:l]))
|
||||
|
||||
def unary_to_binary(l):
|
||||
@@ -89,8 +89,8 @@ class CircuitORAM(PathORAM):
|
||||
self.D = log2(size)
|
||||
self.logD = log2(self.D)
|
||||
self.L = self.D + 1
|
||||
print 'create oram of size %d with depth %d and %d buckets' \
|
||||
% (size, self.D, self.n_buckets())
|
||||
print('create oram of size %d with depth %d and %d buckets' \
|
||||
% (size, self.D, self.n_buckets()))
|
||||
self.value_type = value_type
|
||||
self.index_type = value_type.get_type(self.D)
|
||||
if entry_size is not None:
|
||||
@@ -245,7 +245,7 @@ class CircuitORAM(PathORAM):
|
||||
for i,_ in enumerate(self.recursive_evict_rounds()):
|
||||
Program.prog.curr_tape.start_new_basicblock(name='circuit-recursive-evict-round-%d-%d' % (i, self.size))
|
||||
def recursive_evict_rounds(self):
|
||||
for _ in itertools.izip(self.evict_rounds(), self.index.l.recursive_evict_rounds()):
|
||||
for _ in zip(self.evict_rounds(), self.index.l.recursive_evict_rounds()):
|
||||
yield
|
||||
def bucket_indices_on_path_to(self, leaf):
|
||||
# root is at 1, different to PathORAM
|
||||
@@ -272,10 +272,10 @@ threshold = 2**10
|
||||
|
||||
def OptimalCircuitORAM(size, value_type, *args, **kwargs):
|
||||
if size <= threshold:
|
||||
print size, 'below threshold', threshold
|
||||
print(size, 'below threshold', threshold)
|
||||
return LinearORAM(size, value_type, *args, **kwargs)
|
||||
else:
|
||||
print size, 'above threshold', threshold
|
||||
print(size, 'above threshold', threshold)
|
||||
return RecursiveCircuitORAM(size, value_type, *args, **kwargs)
|
||||
|
||||
class RecursiveCircuitIndexStructure(PackedIndexStructure):
|
||||
|
||||
@@ -28,8 +28,8 @@ use_inv = True
|
||||
# (r[i], r[i]^-1, r[i] * r[i-1]^-1)
|
||||
do_precomp = True
|
||||
|
||||
import instructions_base
|
||||
import util
|
||||
from . import instructions_base
|
||||
from . import util
|
||||
|
||||
def set_variant(options):
|
||||
""" Set flags based on the command-line option provided """
|
||||
@@ -55,7 +55,7 @@ def ld2i(c, n):
|
||||
""" Load immediate 2^n into clear GF(p) register c """
|
||||
t1 = program.curr_block.new_reg('c')
|
||||
ldi(t1, 2 ** (n % 30))
|
||||
for i in range(n / 30):
|
||||
for i in range(n // 30):
|
||||
t2 = program.curr_block.new_reg('c')
|
||||
mulci(t2, t1, 2 ** 30)
|
||||
t1 = t2
|
||||
@@ -75,13 +75,13 @@ def LTZ(s, a, k, kappa):
|
||||
|
||||
k: bit length of a
|
||||
"""
|
||||
from types import sint
|
||||
from .types import sint
|
||||
t = sint()
|
||||
Trunc(t, a, k, k - 1, kappa, True)
|
||||
subsfi(s, t, 0)
|
||||
|
||||
def LessThanZero(a, k, kappa):
|
||||
import types
|
||||
from . import types
|
||||
res = types.sint()
|
||||
LTZ(res, a, k, kappa)
|
||||
return res
|
||||
@@ -124,7 +124,7 @@ def TruncZeroes(a, k, m, signed):
|
||||
if program.options.ring:
|
||||
return TruncLeakyInRing(a, k, m, signed)
|
||||
else:
|
||||
import types
|
||||
from . import types
|
||||
tmp = types.cint()
|
||||
inv2m(tmp, m)
|
||||
return a * tmp
|
||||
@@ -136,7 +136,7 @@ def TruncLeakyInRing(a, k, m, signed):
|
||||
"""
|
||||
assert k > m
|
||||
assert int(program.options.ring) >= k
|
||||
from types import sint, intbitint, cint, cgf2n
|
||||
from .types import sint, intbitint, cint, cgf2n
|
||||
n_bits = k - m
|
||||
n_shift = int(program.options.ring) - n_bits
|
||||
r_bits = [sint.get_random_bit() for i in range(n_bits)]
|
||||
@@ -165,7 +165,7 @@ def TruncRoundNearest(a, k, m, kappa, signed=False):
|
||||
# cannot work with bit length k+1
|
||||
tmp = TruncRing(None, a, k, m - 1, signed)
|
||||
return TruncRing(None, tmp + 1, k - m + 1, 1, signed)
|
||||
from types import sint
|
||||
from .types import sint
|
||||
res = sint()
|
||||
Trunc(res, a + (1 << (m - 1)), k + 1, m, kappa, signed)
|
||||
return res
|
||||
@@ -277,7 +277,7 @@ def BitLTC1(u, a, b, kappa):
|
||||
"""
|
||||
k = len(b)
|
||||
p = [program.curr_block.new_reg('s') for i in range(k)]
|
||||
import floatingpoint
|
||||
from . import floatingpoint
|
||||
a_bits = floatingpoint.bits(a, k)
|
||||
if instructions_base.get_global_vector_size() == 1:
|
||||
a_ = a_bits
|
||||
@@ -357,12 +357,12 @@ def CarryOutAux(d, a, kappa):
|
||||
if k > 1 and k % 2 == 1:
|
||||
a.append(None)
|
||||
k += 1
|
||||
u = [None]*(k/2)
|
||||
u = [None]*(k//2)
|
||||
a = a[::-1]
|
||||
if k > 1:
|
||||
for i in range(k/2):
|
||||
u[i] = carry(a[2*i+1], a[2*i], i != k/2-1)
|
||||
CarryOutAux(d, u[:k/2][::-1], kappa)
|
||||
for i in range(k//2):
|
||||
u[i] = carry(a[2*i+1], a[2*i], i != k//2-1)
|
||||
CarryOutAux(d, u[:k//2][::-1], kappa)
|
||||
else:
|
||||
movs(d, a[0][1])
|
||||
|
||||
@@ -376,7 +376,7 @@ def CarryOut(res, a, b, c=0, kappa=None):
|
||||
c: initial carry-in bit
|
||||
"""
|
||||
k = len(a)
|
||||
import types
|
||||
from . import types
|
||||
d = [program.curr_block.new_reg('s') for i in range(k)]
|
||||
t = [[types.sint() for i in range(k)] for i in range(4)]
|
||||
s = [program.curr_block.new_reg('s') for i in range(3)]
|
||||
@@ -394,7 +394,7 @@ def CarryOut(res, a, b, c=0, kappa=None):
|
||||
|
||||
def CarryOutLE(a, b, c=0):
|
||||
""" Little-endian version """
|
||||
import types
|
||||
from . import types
|
||||
res = types.sint()
|
||||
CarryOut(res, a[::-1], b[::-1], c)
|
||||
return res
|
||||
@@ -407,7 +407,7 @@ def BitLTL(res, a, b, kappa):
|
||||
b: array of secret bits (same length as a)
|
||||
"""
|
||||
k = len(b)
|
||||
import floatingpoint
|
||||
from . import floatingpoint
|
||||
a_bits = floatingpoint.bits(a, k)
|
||||
s = [[program.curr_block.new_reg('s') for i in range(k)] for j in range(2)]
|
||||
t = [program.curr_block.new_reg('s') for i in range(1)]
|
||||
@@ -547,7 +547,7 @@ def KMulC(a):
|
||||
"""
|
||||
Return just the product of all items in a
|
||||
"""
|
||||
from types import sint, cint
|
||||
from .types import sint, cint
|
||||
p = sint()
|
||||
if use_inv:
|
||||
PreMulC_with_inverses(p, a)
|
||||
@@ -582,7 +582,7 @@ def Mod2(a_0, a, k, kappa, signed):
|
||||
adds(t[2], t[0], t[1])
|
||||
adds(t[3], t[2], r_prime)
|
||||
asm_open(c, t[3])
|
||||
import floatingpoint
|
||||
from . import floatingpoint
|
||||
c_0 = floatingpoint.bits(c, 1)[0]
|
||||
mulci(tc, c_0, 2)
|
||||
mulm(t[4], r_0, tc)
|
||||
@@ -591,4 +591,4 @@ def Mod2(a_0, a, k, kappa, signed):
|
||||
|
||||
|
||||
# hack for circular dependency
|
||||
from instructions import *
|
||||
from .instructions import *
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
from Compiler.program import Program
|
||||
from Compiler.config import *
|
||||
from Compiler.exceptions import *
|
||||
import instructions, instructions_base, types, comparison, library
|
||||
import GC.types
|
||||
from . import instructions, instructions_base, types, comparison, library
|
||||
from .GC import types as GC_types
|
||||
|
||||
import random
|
||||
import time
|
||||
@@ -25,36 +25,36 @@ def run(args, options, param=-1, merge_opens=True, emulate=True, \
|
||||
prog.DEBUG = debug
|
||||
VARS['program'] = prog
|
||||
if options.binary:
|
||||
VARS['sint'] = GC.types.sbitint.get_type(int(options.binary))
|
||||
VARS['sfix'] = GC.types.sbitfix
|
||||
VARS['sint'] = GC_types.sbitint.get_type(int(options.binary))
|
||||
VARS['sfix'] = GC_types.sbitfix
|
||||
comparison.set_variant(options)
|
||||
|
||||
print 'Compiling file', prog.infile
|
||||
print('Compiling file', prog.infile)
|
||||
|
||||
# no longer needed, but may want to support assembly in future (?)
|
||||
if assemblymode:
|
||||
prog.restart_main_thread()
|
||||
for i in xrange(INIT_REG_MAX):
|
||||
for i in range(INIT_REG_MAX):
|
||||
VARS['c%d'%i] = prog.curr_block.new_reg('c')
|
||||
VARS['s%d'%i] = prog.curr_block.new_reg('s')
|
||||
VARS['cg%d'%i] = prog.curr_block.new_reg('cg')
|
||||
VARS['sg%d'%i] = prog.curr_block.new_reg('sg')
|
||||
if i % 10000000 == 0 and i > 0:
|
||||
print "Initialized %d register variables at" % i, time.asctime()
|
||||
print("Initialized %d register variables at" % i, time.asctime())
|
||||
|
||||
# first pass determines how many assembler registers are used
|
||||
prog.FIRST_PASS = True
|
||||
execfile(prog.infile, VARS)
|
||||
exec(compile(open(prog.infile).read(), prog.infile, 'exec'), VARS)
|
||||
|
||||
if instructions_base.Instruction.count != 0:
|
||||
print 'instructions count', instructions_base.Instruction.count
|
||||
print('instructions count', instructions_base.Instruction.count)
|
||||
instructions_base.Instruction.count = 0
|
||||
prog.FIRST_PASS = False
|
||||
prog.reset_values()
|
||||
# make compiler modules directly accessible
|
||||
sys.path.insert(0, 'Compiler')
|
||||
# create the tapes
|
||||
execfile(prog.infile, VARS)
|
||||
exec(compile(open(prog.infile).read(), prog.infile, 'exec'), VARS)
|
||||
|
||||
# optimize the tapes
|
||||
for tape in prog.tapes:
|
||||
@@ -66,14 +66,14 @@ def run(args, options, param=-1, merge_opens=True, emulate=True, \
|
||||
sharedmem = list(prog.mem_s)
|
||||
prog.emulate()
|
||||
if prog.mem_c != clearmem or prog.mem_s != sharedmem:
|
||||
print 'Warning: emulated memory values changed after compiler optimization'
|
||||
print('Warning: emulated memory values changed after compiler optimization')
|
||||
# raise CompilerError('Compiler optimization caused incorrect memory write.')
|
||||
|
||||
if prog.main_thread_running:
|
||||
prog.update_req(prog.curr_tape)
|
||||
print 'Program requires:', repr(prog.req_num)
|
||||
print 'Cost:', 0 if prog.req_num is None else prog.req_num.cost()
|
||||
print 'Memory size:', dict(prog.allocated_mem)
|
||||
print('Program requires:', repr(prog.req_num))
|
||||
print('Cost:', 0 if prog.req_num is None else prog.req_num.cost())
|
||||
print('Memory size:', dict(prog.allocated_mem))
|
||||
|
||||
# finalize the memory
|
||||
prog.finalize_memory()
|
||||
|
||||
@@ -86,8 +86,8 @@ class HeapQ(object):
|
||||
self.size = MemValue(int_type(0))
|
||||
self.int_type = int_type
|
||||
self.basic_type = basic_type
|
||||
print 'heap: %d levels, depth %d, size %d, index size %d' % \
|
||||
(self.levels, self.depth, self.heap.oram.size, self.value_index.size)
|
||||
print('heap: %d levels, depth %d, size %d, index size %d' % \
|
||||
(self.levels, self.depth, self.heap.oram.size, self.value_index.size))
|
||||
def update(self, value, prio, for_real=True):
|
||||
self._update(self.basic_type.hard_conv(value), \
|
||||
self.basic_type.hard_conv(prio), \
|
||||
@@ -217,7 +217,7 @@ class HeapQ(object):
|
||||
|
||||
def dijkstra(source, edges, e_index, oram_type, n_loops=None, int_type=sint):
|
||||
basic_type = int_type.basic_type
|
||||
vert_loops = n_loops * e_index.size / edges.size \
|
||||
vert_loops = n_loops * e_index.size // edges.size \
|
||||
if n_loops else -1
|
||||
dist = oram_type(e_index.size, entry_size=(32,log2(e_index.size)), \
|
||||
init_rounds=vert_loops, value_type=basic_type)
|
||||
@@ -287,7 +287,7 @@ def test_dijkstra(G, source, oram_type=ORAM, n_loops=None, int_type=sint):
|
||||
cint(i).print_reg('edge')
|
||||
time()
|
||||
edges[i] = edges_list[i]
|
||||
vert_loops = n_loops * e_index.size / edges.size \
|
||||
vert_loops = n_loops * e_index.size // edges.size \
|
||||
if n_loops else e_index.size
|
||||
for i in range(vert_loops):
|
||||
cint(i).print_reg('vert')
|
||||
@@ -307,7 +307,7 @@ def test_dijkstra_on_cycle(n, oram_type=ORAM, n_loops=None, int_type=sint):
|
||||
time()
|
||||
neighbour = ((i >> 1) + 2 * (i % 2) - 1 + n) % n
|
||||
edges[i] = (neighbour, 1, i % 2)
|
||||
vert_loops = n_loops * e_index.size / edges.size \
|
||||
vert_loops = n_loops * e_index.size // edges.size \
|
||||
if n_loops else e_index.size
|
||||
@for_range(vert_loops)
|
||||
def f(i):
|
||||
@@ -390,14 +390,14 @@ class ExtInt(object):
|
||||
class Vector(object):
|
||||
""" Works like a vector. """
|
||||
def __add__(self, other):
|
||||
print 'add', type(self)
|
||||
print('add', type(self))
|
||||
res = type(self)(len(self))
|
||||
@for_range(len(self))
|
||||
def f(i):
|
||||
res[i] = self[i] + other[i]
|
||||
return res
|
||||
def __sub__(self, other):
|
||||
print 'sub', type(other)
|
||||
print('sub', type(other))
|
||||
res = type(other)(len(self))
|
||||
@for_range(len(self))
|
||||
def f(i):
|
||||
@@ -412,7 +412,7 @@ class Vector(object):
|
||||
res[0] += self[i] * other[i]
|
||||
return res[0]
|
||||
else:
|
||||
print 'mul', type(self)
|
||||
print('mul', type(self))
|
||||
res = type(self)(len(self))
|
||||
@for_range_parallel(1024, len(self))
|
||||
def f(i):
|
||||
@@ -477,7 +477,7 @@ def binarymin(A):
|
||||
if len(A) == 1:
|
||||
return [1], A[0]
|
||||
else:
|
||||
half = len(A) / 2
|
||||
half = len(A) // 2
|
||||
A_prime = VectorArray(half)
|
||||
B = IntVectorArray(half)
|
||||
i = IntVectorArray(len(A))
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
from math import log, floor, ceil
|
||||
from Compiler.instructions import *
|
||||
import types
|
||||
import comparison
|
||||
import program
|
||||
import util
|
||||
from . import types
|
||||
from . import comparison
|
||||
from . import program
|
||||
from . import util
|
||||
|
||||
##
|
||||
## Helper functions for floating point arithmetic
|
||||
@@ -16,7 +16,7 @@ def two_power(n):
|
||||
else:
|
||||
max = types.cint(1) << 31
|
||||
res = 2**(n%31)
|
||||
for i in range(n / 31):
|
||||
for i in range(n // 31):
|
||||
res *= max
|
||||
return res
|
||||
|
||||
@@ -25,7 +25,7 @@ def shift_two(n, pos):
|
||||
return n >> pos
|
||||
else:
|
||||
res = (n >> (pos%63))
|
||||
for i in range(pos / 63):
|
||||
for i in range(pos // 63):
|
||||
res >>= 63
|
||||
return res
|
||||
|
||||
@@ -139,7 +139,7 @@ def PreOpL(op, items):
|
||||
kmax = 2**logk
|
||||
output = list(items)
|
||||
for i in range(logk):
|
||||
for j in range(kmax/(2**(i+1))):
|
||||
for j in range(kmax//(2**(i+1))):
|
||||
y = two_power(i) + j*two_power(i+1) - 1
|
||||
for z in range(1, 2**i+1):
|
||||
if y+z < k:
|
||||
@@ -153,7 +153,7 @@ def PreOpL2(op, items):
|
||||
op must be a binary function that outputs a new register
|
||||
"""
|
||||
k = len(items)
|
||||
half = k / 2
|
||||
half = k // 2
|
||||
output = list(items)
|
||||
if k == 0:
|
||||
return []
|
||||
@@ -161,7 +161,7 @@ def PreOpL2(op, items):
|
||||
v = PreOpL2(op, u)
|
||||
for i in range(half):
|
||||
output[2 * i + 1] = v[i]
|
||||
for i in range(1, (k + 1) / 2):
|
||||
for i in range(1, (k + 1) // 2):
|
||||
output[2 * i] = op(v[i - 1], items[2 * i])
|
||||
return output
|
||||
|
||||
@@ -185,8 +185,8 @@ def KOpL(op, a):
|
||||
if k == 1:
|
||||
return a[0]
|
||||
else:
|
||||
t1 = KOpL(op, a[:k/2])
|
||||
t2 = KOpL(op, a[k/2:])
|
||||
t1 = KOpL(op, a[:k//2])
|
||||
t2 = KOpL(op, a[k//2:])
|
||||
return op(t1, t2)
|
||||
|
||||
def KORL(a, kappa):
|
||||
@@ -195,8 +195,8 @@ def KORL(a, kappa):
|
||||
if k == 1:
|
||||
return a[0]
|
||||
else:
|
||||
t1 = KORL(a[:k/2], kappa)
|
||||
t2 = KORL(a[k/2:], kappa)
|
||||
t1 = KORL(a[:k//2], kappa)
|
||||
t2 = KORL(a[k//2:], kappa)
|
||||
return t1 + t2 - t1*t2
|
||||
|
||||
def KORC(a, kappa):
|
||||
@@ -234,7 +234,7 @@ def BitAdd(a, b, bits_to_compute=None):
|
||||
bits s[0], ... , s[k] """
|
||||
k = len(a)
|
||||
if not bits_to_compute:
|
||||
bits_to_compute = range(k)
|
||||
bits_to_compute = list(range(k))
|
||||
d = [None] * k
|
||||
for i in range(1,k):
|
||||
#assert(a[i].value == 0 or a[i].value == 1)
|
||||
@@ -248,25 +248,25 @@ def BitAdd(a, b, bits_to_compute=None):
|
||||
|
||||
# (for testing)
|
||||
def print_state():
|
||||
print 'a: ',
|
||||
print('a: ', end=' ')
|
||||
for i in range(k):
|
||||
print '%d ' % a[i].value,
|
||||
print '\nb: ',
|
||||
print('%d ' % a[i].value, end=' ')
|
||||
print('\nb: ', end=' ')
|
||||
for i in range(k):
|
||||
print '%d ' % b[i].value,
|
||||
print '\nd: ',
|
||||
print('%d ' % b[i].value, end=' ')
|
||||
print('\nd: ', end=' ')
|
||||
for i in range(k):
|
||||
print '%d ' % d[i][0].value,
|
||||
print '\n ',
|
||||
print('%d ' % d[i][0].value, end=' ')
|
||||
print('\n ', end=' ')
|
||||
for i in range(k):
|
||||
print '%d ' % d[i][1].value,
|
||||
print '\n\npg:',
|
||||
print('%d ' % d[i][1].value, end=' ')
|
||||
print('\n\npg:', end=' ')
|
||||
for i in range(k):
|
||||
print '%d ' % pg[i][0].value,
|
||||
print '\n ',
|
||||
print('%d ' % pg[i][0].value, end=' ')
|
||||
print('\n ', end=' ')
|
||||
for i in range(k):
|
||||
print '%d ' % pg[i][1].value,
|
||||
print ''
|
||||
print('%d ' % pg[i][1].value, end=' ')
|
||||
print('')
|
||||
|
||||
for bit in c:
|
||||
pass#assert(bit.value == 0 or bit.value == 1)
|
||||
@@ -281,7 +281,7 @@ def BitAdd(a, b, bits_to_compute=None):
|
||||
try:
|
||||
pass#assert(s[i].value == 0 or s[i].value == 1)
|
||||
except AssertionError:
|
||||
print '#assertion failed in BitAdd for s[%d]' % i
|
||||
print('#assertion failed in BitAdd for s[%d]' % i)
|
||||
print_state()
|
||||
s[k] = c[k-1]
|
||||
#print_state()
|
||||
@@ -316,9 +316,9 @@ def BitDecField(a, k, m, kappa, bits_to_compute=None):
|
||||
try:
|
||||
pass#assert(c.value == (2**(k + kappa) + 2**k + (a.value%2**k) - rval) % comparison.program.P)
|
||||
except AssertionError:
|
||||
print 'BitDec assertion failed'
|
||||
print 'a =', a.value
|
||||
print 'a mod 2^%d =' % k, (a.value % 2**k)
|
||||
print('BitDec assertion failed')
|
||||
print('a =', a.value)
|
||||
print('a mod 2^%d =' % k, (a.value % 2**k))
|
||||
return types.intbitint.bit_adder(list(bits(c,m)), r)
|
||||
|
||||
|
||||
@@ -503,7 +503,7 @@ def TruncPrRing(a, k, m, signed=True):
|
||||
a += types.sint.get_random_bit() << i
|
||||
return comparison.TruncLeakyInRing(a, k, m, signed=signed)
|
||||
else:
|
||||
from types import sint
|
||||
from .types import sint
|
||||
if signed:
|
||||
a += (1 << (k - 1))
|
||||
if program.Program.prog.use_trunc_pr:
|
||||
|
||||
@@ -19,10 +19,10 @@ class SparseDiGraph(object):
|
||||
if default_attributes is None:
|
||||
default_attributes = { 'merges': None, 'stop': -1, 'start': -1 }
|
||||
self.default_attributes = default_attributes
|
||||
self.attribute_pos = dict(zip(default_attributes.keys(), range(len(default_attributes))))
|
||||
self.attribute_pos = dict(list(zip(list(default_attributes.keys()), list(range(len(default_attributes))))))
|
||||
self.n = max_nodes
|
||||
# each node contains list of default attributes, followed by outoing edges
|
||||
self.nodes = [self.default_attributes.values() for i in range(self.n)]
|
||||
self.nodes = [list(self.default_attributes.values()) for i in range(self.n)]
|
||||
self.succ = [set() for i in range(self.n)]
|
||||
self.pred = [set() for i in range(self.n)]
|
||||
self.weights = {}
|
||||
@@ -45,7 +45,7 @@ class SparseDiGraph(object):
|
||||
raise CompilerError('Cannot add node %d to graph of size %d' % (i, self.n))
|
||||
node = self.nodes[i]
|
||||
|
||||
for a,value in attr.items():
|
||||
for a,value in list(attr.items()):
|
||||
if a in self.default_attributes:
|
||||
node[self.attribute_pos[a]] = value
|
||||
else:
|
||||
@@ -72,7 +72,7 @@ class SparseDiGraph(object):
|
||||
#del self.weights[(v,i)]
|
||||
#self.nodes[v].remove(i)
|
||||
self.pred[i] = []
|
||||
self.nodes[i] = self.default_attributes.values()
|
||||
self.nodes[i] = list(self.default_attributes.values())
|
||||
|
||||
def add_edge(self, i, j, weight=1):
|
||||
if j not in self[i]:
|
||||
@@ -111,7 +111,7 @@ def topological_sort(G, nbunch=None, pref=None):
|
||||
return G[node]
|
||||
else:
|
||||
def get_children(node):
|
||||
if pref.has_key(node):
|
||||
if node in pref:
|
||||
pref_set = set(pref[node])
|
||||
for i in G[node]:
|
||||
if i not in pref_set:
|
||||
@@ -123,7 +123,7 @@ def topological_sort(G, nbunch=None, pref=None):
|
||||
yield i
|
||||
|
||||
if nbunch is None:
|
||||
nbunch = reversed(range(len(G)))
|
||||
nbunch = reversed(list(range(len(G))))
|
||||
for v in nbunch: # process all vertices in G
|
||||
if v in explored:
|
||||
continue
|
||||
@@ -170,8 +170,8 @@ def reverse_dag_shortest_paths(G, source):
|
||||
dist[source] = 0
|
||||
for u in top_order:
|
||||
if u ==68273:
|
||||
print 'dist[68273]', dist[u]
|
||||
print 'pred[u]', G.pred[u]
|
||||
print('dist[68273]', dist[u])
|
||||
print('pred[u]', G.pred[u])
|
||||
if dist[u] is None:
|
||||
continue
|
||||
for v in G.pred[u]:
|
||||
@@ -207,7 +207,7 @@ def longest_paths(G, sources=None):
|
||||
G.weights[edge] = -G.weights[edge]
|
||||
dist = {}
|
||||
for source in sources:
|
||||
print ('%s, ' % source),
|
||||
print(('%s, ' % source), end=' ')
|
||||
dist[source] = dag_shortest_paths(G, source)
|
||||
#dist = johnson(G, sources)
|
||||
# reset weights
|
||||
|
||||
@@ -5,8 +5,8 @@ from Compiler import types
|
||||
|
||||
from Compiler.util import *
|
||||
|
||||
from oram import OptimalORAM,LinearORAM,RecursiveORAM,TrivialORAM,Entry
|
||||
from library import for_range,do_while,time,start_timer,stop_timer,if_,print_ln,crash,print_str
|
||||
from .oram import OptimalORAM,LinearORAM,RecursiveORAM,TrivialORAM,Entry
|
||||
from .library import for_range,do_while,time,start_timer,stop_timer,if_,print_ln,crash,print_str
|
||||
|
||||
class OMatrixRow(object):
|
||||
def __init__(self, oram, base, add_type):
|
||||
@@ -27,7 +27,7 @@ class OMatrixRow(object):
|
||||
|
||||
class OMatrix:
|
||||
def __init__(self, N, M=None, oram_type=OptimalORAM, int_type=types.sint):
|
||||
print 'matrix', oram_type
|
||||
print('matrix', oram_type)
|
||||
self.N = N
|
||||
self.M = M or N
|
||||
self.oram = oram_type(N * self.M, entry_size=log2(N), init_rounds=0, \
|
||||
@@ -73,7 +73,7 @@ class OReverseMatrix(OMatrix):
|
||||
|
||||
class OStack:
|
||||
def __init__(self, N, oram_type=OptimalORAM, int_type=types.sint):
|
||||
print 'stack', oram_type
|
||||
print('stack', oram_type)
|
||||
self.oram = oram_type(N, entry_size=log2(N), init_rounds=0, \
|
||||
value_type=int_type.basic_type)
|
||||
self.size = types.MemValue(int_type(0))
|
||||
@@ -247,4 +247,4 @@ class Matchmaker:
|
||||
self.reverse = reverse
|
||||
self.int_type = int_type
|
||||
self.basic_type = int_type.basic_type
|
||||
print 'match', self.oram_type
|
||||
print('match', self.oram_type)
|
||||
|
||||
@@ -11,7 +11,7 @@ documentation
|
||||
"""
|
||||
|
||||
import itertools
|
||||
import tools
|
||||
from . import tools
|
||||
from random import randint
|
||||
from Compiler.config import *
|
||||
from Compiler.exceptions import *
|
||||
@@ -51,7 +51,7 @@ class ldsi(base.Instruction):
|
||||
@base.vectorize
|
||||
class ldmc(base.DirectMemoryInstruction, base.ReadMemoryInstruction):
|
||||
r""" Assigns register $c_i$ the value in memory \verb+C[n]+. """
|
||||
__slots__ = ["code"]
|
||||
__slots__ = []
|
||||
code = base.opcodes['LDMC']
|
||||
arg_format = ['cw','int']
|
||||
|
||||
@@ -62,7 +62,7 @@ class ldmc(base.DirectMemoryInstruction, base.ReadMemoryInstruction):
|
||||
@base.vectorize
|
||||
class ldms(base.DirectMemoryInstruction, base.ReadMemoryInstruction):
|
||||
r""" Assigns register $s_i$ the value in memory \verb+S[n]+. """
|
||||
__slots__ = ["code"]
|
||||
__slots__ = []
|
||||
code = base.opcodes['LDMS']
|
||||
arg_format = ['sw','int']
|
||||
|
||||
@@ -73,7 +73,7 @@ class ldms(base.DirectMemoryInstruction, base.ReadMemoryInstruction):
|
||||
@base.vectorize
|
||||
class stmc(base.DirectMemoryWriteInstruction):
|
||||
r""" Sets \verb+C[n]+ to be the value $c_i$. """
|
||||
__slots__ = ["code"]
|
||||
__slots__ = []
|
||||
code = base.opcodes['STMC']
|
||||
arg_format = ['c','int']
|
||||
|
||||
@@ -84,7 +84,7 @@ class stmc(base.DirectMemoryWriteInstruction):
|
||||
@base.vectorize
|
||||
class stms(base.DirectMemoryWriteInstruction):
|
||||
r""" Sets \verb+S[n]+ to be the value $s_i$. """
|
||||
__slots__ = ["code"]
|
||||
__slots__ = []
|
||||
code = base.opcodes['STMS']
|
||||
arg_format = ['s','int']
|
||||
|
||||
@@ -94,7 +94,7 @@ class stms(base.DirectMemoryWriteInstruction):
|
||||
@base.vectorize
|
||||
class ldmint(base.DirectMemoryInstruction, base.ReadMemoryInstruction):
|
||||
r""" Assigns register $ci_i$ the value in memory \verb+Ci[n]+. """
|
||||
__slots__ = ["code"]
|
||||
__slots__ = []
|
||||
code = base.opcodes['LDMINT']
|
||||
arg_format = ['ciw','int']
|
||||
|
||||
@@ -104,7 +104,7 @@ class ldmint(base.DirectMemoryInstruction, base.ReadMemoryInstruction):
|
||||
@base.vectorize
|
||||
class stmint(base.DirectMemoryWriteInstruction):
|
||||
r""" Sets \verb+Ci[n]+ to be the value $ci_i$. """
|
||||
__slots__ = ["code"]
|
||||
__slots__ = []
|
||||
code = base.opcodes['STMINT']
|
||||
arg_format = ['ci','int']
|
||||
|
||||
@@ -227,7 +227,7 @@ class protectmemint(base.Instruction):
|
||||
@base.vectorize
|
||||
class movc(base.Instruction):
|
||||
r""" Assigns register $c_i$ the value in the register $c_j$. """
|
||||
__slots__ = ["code"]
|
||||
__slots__ = []
|
||||
code = base.opcodes['MOVC']
|
||||
arg_format = ['cw','c']
|
||||
|
||||
@@ -238,7 +238,7 @@ class movc(base.Instruction):
|
||||
@base.vectorize
|
||||
class movs(base.Instruction):
|
||||
r""" Assigns register $s_i$ the value in the register $s_j$. """
|
||||
__slots__ = ["code"]
|
||||
__slots__ = []
|
||||
code = base.opcodes['MOVS']
|
||||
arg_format = ['sw','s']
|
||||
|
||||
@@ -248,7 +248,7 @@ class movs(base.Instruction):
|
||||
@base.vectorize
|
||||
class movint(base.Instruction):
|
||||
r""" Assigns register $ci_i$ the value in the register $ci_j$. """
|
||||
__slots__ = ["code"]
|
||||
__slots__ = []
|
||||
code = base.opcodes['MOVINT']
|
||||
arg_format = ['ciw','ci']
|
||||
|
||||
@@ -347,6 +347,21 @@ class use_prep(base.Instruction):
|
||||
code = base.opcodes['USE_PREP']
|
||||
arg_format = ['str','int']
|
||||
|
||||
class nplayers(base.Instruction):
|
||||
r""" Number of players """
|
||||
code = base.opcodes['NPLAYERS']
|
||||
arg_format = ['ciw']
|
||||
|
||||
class threshold(base.Instruction):
|
||||
r""" Maximal number of corrupt players """
|
||||
code = base.opcodes['THRESHOLD']
|
||||
arg_format = ['ciw']
|
||||
|
||||
class playerid(base.Instruction):
|
||||
r""" My player number """
|
||||
code = base.opcodes['PLAYERID']
|
||||
arg_format = ['ciw']
|
||||
|
||||
###
|
||||
### Basic arithmetic
|
||||
###
|
||||
@@ -738,7 +753,7 @@ class shrci(base.ClearShiftInstruction):
|
||||
class triple(base.DataInstruction):
|
||||
r""" Load secret variables $s_i$, $s_j$ and $s_k$
|
||||
with the next multiplication triple. """
|
||||
__slots__ = ['data_type']
|
||||
__slots__ = []
|
||||
code = base.opcodes['TRIPLE']
|
||||
arg_format = ['sw','sw','sw']
|
||||
data_type = 'triple'
|
||||
@@ -752,7 +767,7 @@ class triple(base.DataInstruction):
|
||||
class gbittriple(base.DataInstruction):
|
||||
r""" Load secret variables $s_i$, $s_j$ and $s_k$
|
||||
with the next GF(2) multiplication triple. """
|
||||
__slots__ = ['data_type']
|
||||
__slots__ = []
|
||||
code = base.opcodes['GBITTRIPLE']
|
||||
arg_format = ['sgw','sgw','sgw']
|
||||
data_type = 'bittriple'
|
||||
@@ -1400,7 +1415,8 @@ class convmodp(base.Instruction):
|
||||
bitlength = program.bit_length if bitlength is None else bitlength
|
||||
if bitlength > 64:
|
||||
raise CompilerError('%d-bit conversion requested ' \
|
||||
'but integer registers only have 64 bits')
|
||||
'but integer registers only have 64 bits' % \
|
||||
bitlength)
|
||||
super(convmodp_class, self).__init__(*(args + (bitlength,)))
|
||||
|
||||
@base.vectorize
|
||||
@@ -1433,7 +1449,7 @@ class muls(base.VarArgsInstruction, base.DataInstruction):
|
||||
data_type = 'triple'
|
||||
|
||||
def get_repeat(self):
|
||||
return len(self.args) / 3
|
||||
return len(self.args) // 3
|
||||
|
||||
def merge_id(self):
|
||||
# can merge different sizes
|
||||
@@ -1508,6 +1524,8 @@ class dotprods(base.VarArgsInstruction, base.DataInstruction):
|
||||
for j in range(self.args[i] - 2):
|
||||
yield 's' + field
|
||||
|
||||
gf2n_arg_format = arg_format
|
||||
|
||||
def bases(self):
|
||||
i = 0
|
||||
while i < len(self.args):
|
||||
@@ -1515,7 +1533,7 @@ class dotprods(base.VarArgsInstruction, base.DataInstruction):
|
||||
i += self.args[i]
|
||||
|
||||
def get_repeat(self):
|
||||
return sum(self.args[i] / 2 for i in self.bases()) * self.get_size()
|
||||
return sum(self.args[i] // 2 for i in self.bases()) * self.get_size()
|
||||
|
||||
def get_def(self):
|
||||
return [self.args[i + 1] for i in self.bases()]
|
||||
@@ -1567,7 +1585,7 @@ class lts(base.CISC):
|
||||
arg_format = ['sw', 's', 's', 'int', 'int']
|
||||
|
||||
def expand(self):
|
||||
from types import sint
|
||||
from .types import sint
|
||||
a = sint()
|
||||
subs(a, self.args[1], self.args[2])
|
||||
comparison.LTZ(self.args[0], a, self.args[3], self.args[4])
|
||||
|
||||
@@ -56,6 +56,9 @@ opcodes = dict(
|
||||
USE_PREP = 0x1C,
|
||||
STARTGRIND = 0x1D,
|
||||
STOPGRIND = 0x1E,
|
||||
NPLAYERS = 0xE2,
|
||||
THRESHOLD = 0xE3,
|
||||
PLAYERID = 0xE4,
|
||||
# Addition
|
||||
ADDC = 0x20,
|
||||
ADDS = 0x21,
|
||||
@@ -277,7 +280,7 @@ def gf2n(instruction):
|
||||
vectorized GF_2^n instruction if a modp version exists. """
|
||||
global_dict = inspect.getmodule(instruction).__dict__
|
||||
|
||||
if global_dict.has_key('v' + instruction.__name__):
|
||||
if 'v' + instruction.__name__ in global_dict:
|
||||
vectorized = True
|
||||
else:
|
||||
vectorized = False
|
||||
@@ -316,7 +319,7 @@ def gf2n(instruction):
|
||||
if 'gf2n_arg_format' in instruction_cls.__dict__:
|
||||
arg_format = instruction_cls.gf2n_arg_format
|
||||
elif isinstance(instruction_cls.arg_format, itertools.repeat):
|
||||
__f = instruction_cls.arg_format.next()
|
||||
__f = next(instruction_cls.arg_format)
|
||||
if __f != 'int' and __f != 'p':
|
||||
arg_format = itertools.repeat(__f[0] + 'g' + __f[1:])
|
||||
else:
|
||||
@@ -420,7 +423,7 @@ class ClearIntAF(RegisterArgFormat):
|
||||
class IntArgFormat(ArgFormat):
|
||||
@classmethod
|
||||
def check(cls, arg):
|
||||
if not isinstance(arg, (int, long)):
|
||||
if not isinstance(arg, int):
|
||||
raise ArgumentError(arg, 'Expected an integer-valued argument')
|
||||
|
||||
@classmethod
|
||||
@@ -512,7 +515,7 @@ class Instruction(object):
|
||||
|
||||
Instruction.count += 1
|
||||
if Instruction.count % 100000 == 0:
|
||||
print "Compiled %d lines at" % self.__class__.count, time.asctime()
|
||||
print("Compiled %d lines at" % self.__class__.count, time.asctime())
|
||||
|
||||
def get_code(self):
|
||||
return self.code
|
||||
@@ -535,7 +538,7 @@ class Instruction(object):
|
||||
|
||||
def check_args(self):
|
||||
""" Check the args match up with that specified in arg_format """
|
||||
for n,(arg,f) in enumerate(itertools.izip_longest(self.args, self.arg_format)):
|
||||
for n,(arg,f) in enumerate(itertools.zip_longest(self.args, self.arg_format)):
|
||||
if arg is None:
|
||||
if not isinstance(self.arg_format, (list, tuple)):
|
||||
break # end of optional arguments
|
||||
@@ -758,8 +761,6 @@ class ClearShiftInstruction(ClearImmediate):
|
||||
else:
|
||||
# assume 64-bit machine
|
||||
bits = 63
|
||||
elif program.options.ring:
|
||||
bits = int(program.options.ring) - 1
|
||||
if self.args[2] > bits:
|
||||
raise CompilerError('Shifting by more than %d bits '
|
||||
'not implemented' % bits)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from Compiler.types import cint,sint,cfix,sfix,sfloat,MPCThread,Array,MemValue,cgf2n,sgf2n,_number,_mem,_register,regint,Matrix,_types, cfloat, _single
|
||||
from Compiler.types import cint,sint,cfix,sfix,sfloat,MPCThread,Array,MemValue,cgf2n,sgf2n,_number,_mem,_register,regint,Matrix,_types, cfloat, _single, localint
|
||||
from Compiler.instructions import *
|
||||
from Compiler.util import tuplify,untuplify
|
||||
from Compiler import instructions,instructions_base,comparison,program,util
|
||||
@@ -6,6 +6,7 @@ import inspect,math
|
||||
import random
|
||||
import collections
|
||||
import operator
|
||||
from functools import reduce
|
||||
|
||||
def get_program():
|
||||
return instructions.program
|
||||
@@ -101,6 +102,8 @@ def print_ln_if(cond, ss, *args):
|
||||
else:
|
||||
subs = ss.split('%s')
|
||||
assert len(subs) == len(args) + 1
|
||||
if isinstance(cond, localint):
|
||||
cond = cond._v
|
||||
cond = cint.conv(cond)
|
||||
for i, s in enumerate(subs):
|
||||
if i != 0:
|
||||
@@ -274,7 +277,7 @@ class FunctionTapeCall:
|
||||
def join(self):
|
||||
self.thread.join()
|
||||
instructions.program.free(self.base, 'ci')
|
||||
for reg_type,addr in self.bases.iteritems():
|
||||
for reg_type,addr in self.bases.items():
|
||||
get_program().free(addr, reg_type.reg_type)
|
||||
|
||||
class Function:
|
||||
@@ -287,7 +290,7 @@ class Function:
|
||||
self.compile_args = compile_args
|
||||
def __call__(self, *args):
|
||||
args = tuple(arg.read() if isinstance(arg, MemValue) else arg for arg in args)
|
||||
get_reg_type = lambda x: regint if isinstance(x, (int, long)) else type(x)
|
||||
get_reg_type = lambda x: regint if isinstance(x, int) else type(x)
|
||||
if len(args) not in self.type_args:
|
||||
# first call
|
||||
type_args = collections.defaultdict(list)
|
||||
@@ -296,9 +299,11 @@ class Function:
|
||||
def wrapped_function(*compile_args):
|
||||
base = get_arg()
|
||||
bases = dict((t, regint.load_mem(base + i)) \
|
||||
for i,t in enumerate(sorted(type_args)))
|
||||
for i,t in enumerate(sorted(type_args,
|
||||
key=lambda x:
|
||||
x.reg_type)))
|
||||
runtime_args = [None] * len(args)
|
||||
for t in sorted(type_args):
|
||||
for t in sorted(type_args, key=lambda x: x.reg_type):
|
||||
for i,i_arg in enumerate(type_args[t]):
|
||||
runtime_args[i_arg] = t.load_mem(bases[t] + i)
|
||||
return self.function(*(list(compile_args) + runtime_args))
|
||||
@@ -308,7 +313,8 @@ class Function:
|
||||
base = instructions.program.malloc(len(type_args), 'ci')
|
||||
bases = dict((t, get_program().malloc(len(type_args[t]), t)) \
|
||||
for t in type_args)
|
||||
for i,reg_type in enumerate(sorted(type_args)):
|
||||
for i,reg_type in enumerate(sorted(type_args,
|
||||
key=lambda x: x.reg_type)):
|
||||
store_in_mem(bases[reg_type], base + i)
|
||||
for j,i_arg in enumerate(type_args[reg_type]):
|
||||
if get_reg_type(args[i_arg]) != reg_type:
|
||||
@@ -353,13 +359,13 @@ class FunctionBlock(Function):
|
||||
block.alloc_pool = defaultdict(set)
|
||||
del parent_node.children[-1]
|
||||
self.node = get_tape().req_node
|
||||
print 'Compiling function', self.name
|
||||
print('Compiling function', self.name)
|
||||
result = wrapped_function(*self.compile_args)
|
||||
if result is not None:
|
||||
self.result = memorize(result)
|
||||
else:
|
||||
self.result = None
|
||||
print 'Done compiling function', self.name
|
||||
print('Done compiling function', self.name)
|
||||
p_return_address = get_tape().program.malloc(1, 'ci')
|
||||
get_tape().function_basicblocks[block] = p_return_address
|
||||
return_address = regint.load_mem(p_return_address)
|
||||
@@ -429,7 +435,7 @@ def sort(a):
|
||||
res = a
|
||||
|
||||
for i in range(len(a)):
|
||||
for j in reversed(range(i)):
|
||||
for j in reversed(list(range(i))):
|
||||
res[j], res[j+1] = cond_swap(res[j], res[j+1])
|
||||
|
||||
return res
|
||||
@@ -443,7 +449,7 @@ def odd_even_merge(a):
|
||||
odd_even_merge(even)
|
||||
odd_even_merge(odd)
|
||||
a[0] = even[0]
|
||||
for i in range(1, len(a) / 2):
|
||||
for i in range(1, len(a) // 2):
|
||||
a[2*i-1], a[2*i] = cond_swap(odd[i-1], even[i])
|
||||
a[-1] = odd[-1]
|
||||
|
||||
@@ -451,8 +457,8 @@ def odd_even_merge_sort(a):
|
||||
if len(a) == 1:
|
||||
return
|
||||
elif len(a) % 2 == 0:
|
||||
lower = a[:len(a)/2]
|
||||
upper = a[len(a)/2:]
|
||||
lower = a[:len(a)//2]
|
||||
upper = a[len(a)//2:]
|
||||
odd_even_merge_sort(lower)
|
||||
odd_even_merge_sort(upper)
|
||||
a[:] = lower + upper
|
||||
@@ -472,10 +478,10 @@ def chunky_odd_even_merge_sort(a):
|
||||
def round():
|
||||
for i in range(len(a)):
|
||||
a[i] = type(a[i]).load_mem(i * a[i].sizeof())
|
||||
for i in range(len(a) / l):
|
||||
for j in range(l / k):
|
||||
for i in range(len(a) // l):
|
||||
for j in range(l // k):
|
||||
base = i * l + j
|
||||
step = l / k
|
||||
step = l // k
|
||||
if k == 2:
|
||||
a[base], a[base+step] = cond_swap(a[base], a[base+step])
|
||||
else:
|
||||
@@ -514,7 +520,7 @@ def chunkier_odd_even_merge_sort(a, n=None, max_chunk_size=512, n_threads=7, use
|
||||
def run_chunk(size, base):
|
||||
if size not in chunks:
|
||||
def swap_list(list_base):
|
||||
for i in range(size / 2):
|
||||
for i in range(size // 2):
|
||||
base = list_base + 2 * i
|
||||
x, y = cond_swap(load_secret_mem(base),
|
||||
load_secret_mem(base + 1))
|
||||
@@ -526,8 +532,8 @@ def chunkier_odd_even_merge_sort(a, n=None, max_chunk_size=512, n_threads=7, use
|
||||
def run_round(size):
|
||||
# minimize number of chunk sizes
|
||||
n_chunks = int(math.ceil(1.0 * size / max_chunk_size))
|
||||
lower_size = size / n_chunks / 2 * 2
|
||||
n_lower_size = n_chunks - (size - n_chunks * lower_size) / 2
|
||||
lower_size = size // n_chunks // 2 * 2
|
||||
n_lower_size = n_chunks - (size - n_chunks * lower_size) // 2
|
||||
# print len(to_swap) == lower_size * n_lower_size + \
|
||||
# (lower_size + 2) * (n_chunks - n_lower_size), \
|
||||
# len(to_swap), n_chunks, lower_size, n_lower_size
|
||||
@@ -603,10 +609,10 @@ def chunkier_odd_even_merge_sort(a, n=None, max_chunk_size=512, n_threads=7, use
|
||||
k *= 2
|
||||
size = 0
|
||||
instructions.program.curr_tape.merge_opens = False
|
||||
for i in range(n / l):
|
||||
for j in range(l / k):
|
||||
for i in range(n // l):
|
||||
for j in range(l // k):
|
||||
base = i * l + j
|
||||
step = l / k
|
||||
step = l // k
|
||||
size += run_setup(k, a_base + base, step, tmp_base + size)
|
||||
run_threads_in_rounds(pre_threads)
|
||||
run_round(size)
|
||||
@@ -651,7 +657,7 @@ def loopy_chunkier_odd_even_merge_sort(a, n=None, max_chunk_size=512, n_threads=
|
||||
def run_chunk(size, base):
|
||||
if size not in chunks:
|
||||
def swap_list(list_base):
|
||||
for i in range(size / 2):
|
||||
for i in range(size // 2):
|
||||
base = list_base + 2 * i
|
||||
x, y = cond_swap(load_secret_mem(base),
|
||||
load_secret_mem(base + 1))
|
||||
@@ -663,8 +669,8 @@ def loopy_chunkier_odd_even_merge_sort(a, n=None, max_chunk_size=512, n_threads=
|
||||
def run_round(size):
|
||||
# minimize number of chunk sizes
|
||||
n_chunks = int(math.ceil(1.0 * size / max_chunk_size))
|
||||
lower_size = size / n_chunks / 2 * 2
|
||||
n_lower_size = n_chunks - (size - n_chunks * lower_size) / 2
|
||||
lower_size = size // n_chunks // 2 * 2
|
||||
n_lower_size = n_chunks - (size - n_chunks * lower_size) // 2
|
||||
# print len(to_swap) == lower_size * n_lower_size + \
|
||||
# (lower_size + 2) * (n_chunks - n_lower_size), \
|
||||
# len(to_swap), n_chunks, lower_size, n_lower_size
|
||||
@@ -692,7 +698,7 @@ def loopy_chunkier_odd_even_merge_sort(a, n=None, max_chunk_size=512, n_threads=
|
||||
def outer(i):
|
||||
def inner(j):
|
||||
base = j
|
||||
step = l / k
|
||||
step = l // k
|
||||
if k == 2:
|
||||
tmp_addr = regint.load_mem(tmp_i)
|
||||
load_and_store(base, tmp_addr)
|
||||
@@ -704,19 +710,19 @@ def loopy_chunkier_odd_even_merge_sort(a, n=None, max_chunk_size=512, n_threads=
|
||||
load_and_store(m, tmp_addr)
|
||||
store_in_mem(tmp_addr + 1, tmp_i)
|
||||
range_loop(inner2, base + step, base + (k - 1) * step, step)
|
||||
range_loop(inner, a_base + i * l, a_base + i * l + l / k)
|
||||
range_loop(inner, a_base + i * l, a_base + i * l + l // k)
|
||||
instructions.program.curr_tape.merge_opens = False
|
||||
to_tmp = True
|
||||
store_in_mem(tmp_base, tmp_i)
|
||||
range_loop(outer, n / l)
|
||||
range_loop(outer, n // l)
|
||||
if k == 2:
|
||||
run_round(n)
|
||||
else:
|
||||
run_round(n / k * (k - 2))
|
||||
run_round(n // k * (k - 2))
|
||||
instructions.program.curr_tape.merge_opens = False
|
||||
to_tmp = False
|
||||
store_in_mem(tmp_base, tmp_i)
|
||||
range_loop(outer, n / l)
|
||||
range_loop(outer, n // l)
|
||||
|
||||
if isinstance(a, list):
|
||||
instructions.program.restart_main_thread()
|
||||
@@ -734,15 +740,15 @@ def loopy_odd_even_merge_sort(a, sorted_length=1, n_parallel=32):
|
||||
k = 1
|
||||
while k < l:
|
||||
k *= 2
|
||||
n_outer = len(a) / l
|
||||
n_inner = l / k
|
||||
n_innermost = 1 if k == 2 else k / 2 - 1
|
||||
@for_range_parallel(n_parallel / n_innermost / n_inner, n_outer)
|
||||
n_outer = len(a) // l
|
||||
n_inner = l // k
|
||||
n_innermost = 1 if k == 2 else k // 2 - 1
|
||||
@for_range_parallel(n_parallel // n_innermost // n_inner, n_outer)
|
||||
def loop(i):
|
||||
@for_range_parallel(n_parallel / n_innermost, n_inner)
|
||||
@for_range_parallel(n_parallel // n_innermost, n_inner)
|
||||
def inner(j):
|
||||
base = i*l + j
|
||||
step = l/k
|
||||
step = l//k
|
||||
if k == 2:
|
||||
a[base], a[base+step] = cond_swap(a[base], a[base+step])
|
||||
else:
|
||||
@@ -805,7 +811,7 @@ def range_loop(loop_body, start, stop=None, step=None):
|
||||
# known loop count
|
||||
if condition(start):
|
||||
get_tape().req_node.children[-1].aggregator = \
|
||||
lambda x: ((stop - start) / step) * x[0]
|
||||
lambda x: ((stop - start) // step) * x[0]
|
||||
|
||||
def for_range(start, stop=None, step=None):
|
||||
""" Execute loop bodies consecutively """
|
||||
@@ -840,7 +846,7 @@ def map_reduce_single(n_parallel, n_loops, initializer=lambda *x: [],
|
||||
my_n_parallel = n_parallel
|
||||
if isinstance(n_parallel, int):
|
||||
if isinstance(n_loops, int):
|
||||
loop_rounds = n_loops / n_parallel \
|
||||
loop_rounds = n_loops // n_parallel \
|
||||
if n_parallel < n_loops else 0
|
||||
else:
|
||||
loop_rounds = n_loops / n_parallel
|
||||
@@ -884,7 +890,7 @@ def map_reduce_single(n_parallel, n_loops, initializer=lambda *x: [],
|
||||
regint.push(k)
|
||||
return i + k
|
||||
my_n_parallel = n_opt_loops
|
||||
loop_rounds = n_loops / my_n_parallel
|
||||
loop_rounds = n_loops // my_n_parallel
|
||||
blocks = get_tape().basicblocks
|
||||
n_to_merge = 5
|
||||
if loop_rounds == 1 and parent_block is blocks[-n_to_merge]:
|
||||
@@ -966,7 +972,7 @@ def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \
|
||||
indices = []
|
||||
for n in reversed(split):
|
||||
indices.insert(0, i % n)
|
||||
i /= n
|
||||
i //= n
|
||||
return loop_body(*indices)
|
||||
return new_body
|
||||
new_dec = map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, thread_mem_req)
|
||||
@@ -979,7 +985,7 @@ def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \
|
||||
else:
|
||||
return dec
|
||||
def decorator(loop_body):
|
||||
thread_rounds = n_loops / n_threads
|
||||
thread_rounds = n_loops // n_threads
|
||||
remainder = n_loops % n_threads
|
||||
for t in thread_mem_req:
|
||||
if t != regint:
|
||||
@@ -1233,10 +1239,25 @@ def stop_timer(timer_id=0):
|
||||
stop(timer_id)
|
||||
get_tape().start_new_basicblock(name='post-stop-timer')
|
||||
|
||||
def get_number_of_players():
|
||||
res = regint()
|
||||
nplayers(res)
|
||||
return res
|
||||
|
||||
def get_threshold():
|
||||
res = regint()
|
||||
threshold(res)
|
||||
return res
|
||||
|
||||
def get_player_id():
|
||||
res = localint()
|
||||
playerid(res._v)
|
||||
return res
|
||||
|
||||
# Fixed point ops
|
||||
|
||||
from math import ceil, log
|
||||
from floatingpoint import PreOR, TruncPr, two_power, shift_two
|
||||
from .floatingpoint import PreOR, TruncPr, two_power, shift_two
|
||||
|
||||
def approximate_reciprocal(divisor, k, f, theta):
|
||||
"""
|
||||
@@ -1369,7 +1390,7 @@ def FPDiv(a, b, k, f, kappa, simplex_flag=False, nearest=False):
|
||||
# no probabilistic truncation in binary circuits
|
||||
nearest = True
|
||||
res_f = f
|
||||
f = max((k - nearest) / 2 + 1, f)
|
||||
f = max((k - nearest) // 2 + 1, f)
|
||||
assert 2 * f > k - nearest
|
||||
theta = int(ceil(log(k/3.5) / log(2)))
|
||||
alpha = b.get_type(2 * k).two_power(2*f)
|
||||
@@ -1387,7 +1408,7 @@ def FPDiv(a, b, k, f, kappa, simplex_flag=False, nearest=False):
|
||||
x = x.round(2*k, 2*f, kappa, nearest, signed=True)
|
||||
|
||||
y = y.extend(2 * k) * (alpha + x).extend(2 * k)
|
||||
y = y.round(k + 2 * f, 3 * f - res_f, kappa, nearest, signed=True)
|
||||
y = y.round(k + 3 * f - res_f, 3 * f - res_f, kappa, nearest, signed=True)
|
||||
return y
|
||||
def AppRcr(b, k, f, kappa, simplex_flag=False, nearest=False):
|
||||
"""
|
||||
|
||||
@@ -3,6 +3,7 @@ import mpc_math, math
|
||||
from Compiler.types import *
|
||||
from Compiler.types import _unreduced_squant
|
||||
from Compiler.library import *
|
||||
from functools import reduce
|
||||
|
||||
def log_e(x):
|
||||
return mpc_math.log_fx(x, math.e)
|
||||
@@ -129,7 +130,7 @@ class DenseBase(Layer):
|
||||
tmp[j][k] = sfix.unreduced_dot_product(a, b)
|
||||
|
||||
if self.d_in * self.d_out < 100000:
|
||||
print 'reduce at once'
|
||||
print('reduce at once')
|
||||
@multithread(self.n_threads, self.d_in * self.d_out)
|
||||
def _(base, size):
|
||||
self.nabla_W.assign_vector(
|
||||
@@ -386,7 +387,7 @@ class QuantConvBase(QuantBase):
|
||||
s.set_params(sfloat.get_input_from(player), sint.get_input_from(player))
|
||||
self.weights.input_from(player, budget=100000)
|
||||
self.bias.input_from(player)
|
||||
print 'WARNING: assuming that bias quantization parameters are correct'
|
||||
print('WARNING: assuming that bias quantization parameters are correct')
|
||||
|
||||
self.output_squant.params.precompute(self.input_squant.params, self.weight_squant.params)
|
||||
|
||||
@@ -404,7 +405,7 @@ class QuantConvBase(QuantBase):
|
||||
start_timer(2)
|
||||
n_outputs = reduce(operator.mul, self.output_shape)
|
||||
if n_outputs % self.n_threads == 0:
|
||||
n_per_thread = n_outputs / self.n_threads
|
||||
n_per_thread = n_outputs // self.n_threads
|
||||
@for_range_opt_multithread(self.n_threads, self.n_threads)
|
||||
def _(i):
|
||||
res = _unreduced_squant(
|
||||
@@ -556,7 +557,7 @@ class QuantAveragePool2d(QuantBase):
|
||||
self.filter_size = filter_size
|
||||
|
||||
def input_from(self, player):
|
||||
print 'WARNING: assuming that input and output quantization parameters are the same'
|
||||
print('WARNING: assuming that input and output quantization parameters are the same')
|
||||
for s in self.input_squant, self.output_squant:
|
||||
s.set_params(sfloat.get_input_from(player), sint.get_input_from(player))
|
||||
|
||||
@@ -567,7 +568,7 @@ class QuantAveragePool2d(QuantBase):
|
||||
_, output_h, output_w, n_channels_out = self.output_shape
|
||||
|
||||
n = input_h * input_w
|
||||
print 'divisor: ', n
|
||||
print('divisor: ', n)
|
||||
|
||||
assert output_h == output_w == 1
|
||||
assert n_channels_in == n_channels_out
|
||||
@@ -599,7 +600,7 @@ class QuantAveragePool2d(QuantBase):
|
||||
acc += self.X[0][in_y][in_x][c].v
|
||||
#fc += 1
|
||||
logn = int(math.log(n, 2))
|
||||
acc = (acc + n / 2)
|
||||
acc = (acc + n // 2)
|
||||
if 2 ** logn == n:
|
||||
acc = acc.round(self.output_squant.params.k + logn,
|
||||
logn, nearest=True)
|
||||
@@ -614,7 +615,7 @@ class QuantReshape(QuantBase):
|
||||
super(QuantReshape, self).__init__(input_shape, output_shape)
|
||||
|
||||
def input_from(self, player):
|
||||
print 'WARNING: assuming that input and output quantization parameters are the same'
|
||||
print('WARNING: assuming that input and output quantization parameters are the same')
|
||||
_ = self.new_squant()
|
||||
for s in self.input_squant, _, self.output_squant:
|
||||
s.set_params(sfloat.get_input_from(player), sint.get_input_from(player))
|
||||
@@ -628,7 +629,7 @@ class QuantReshape(QuantBase):
|
||||
|
||||
class QuantSoftmax(QuantBase):
|
||||
def input_from(self, player):
|
||||
print 'WARNING: assuming that input and output quantization parameters are the same'
|
||||
print('WARNING: assuming that input and output quantization parameters are the same')
|
||||
for s in self.input_squant, self.output_squant:
|
||||
s.set_params(sfloat.get_input_from(player), sint.get_input_from(player))
|
||||
|
||||
@@ -666,14 +667,14 @@ class Optimizer:
|
||||
N = self.layers[0].N
|
||||
assert self.layers[-1].N == N
|
||||
assert N % 2 == 0
|
||||
n = N / 2
|
||||
n = N // 2
|
||||
@for_range(n)
|
||||
def _(i):
|
||||
self.layers[-1].Y[i] = 0
|
||||
self.layers[-1].Y[i + n] = 1
|
||||
n_per_epoch = int(math.ceil(1. * max(len(X) for X in
|
||||
self.X_by_label) / n))
|
||||
print '%d runs per epoch' % n_per_epoch
|
||||
print('%d runs per epoch' % n_per_epoch)
|
||||
indices_by_label = []
|
||||
for label, X in enumerate(self.X_by_label):
|
||||
indices = regint.Array(n * n_per_epoch)
|
||||
@@ -794,8 +795,8 @@ class SGD(Optimizer):
|
||||
x = x.reveal()
|
||||
print_ln_if((x > 1000) + (x < -1000),
|
||||
name + ': %s %s %s %s',
|
||||
*[y.v.reveal() for y in old, red_old, \
|
||||
new, diff])
|
||||
*[y.v.reveal() for y in (old, red_old, \
|
||||
new, diff)])
|
||||
if self.debug:
|
||||
d = delta_theta.get_vector().reveal()
|
||||
a = cfix.Array(len(d.v))
|
||||
|
||||
@@ -474,7 +474,7 @@ def norm_simplified_SQ(b, k):
|
||||
m_odd = m_odd + z[i]
|
||||
|
||||
# construct w,
|
||||
k_over_2 = k / 2 + 1
|
||||
k_over_2 = k // 2 + 1
|
||||
w_array = [0] * (k_over_2)
|
||||
w_array[0] = z[0]
|
||||
for i in range(1, k_over_2):
|
||||
@@ -510,7 +510,7 @@ def sqrt_simplified_fx(x):
|
||||
m_odd = (1 - 2 * m_odd) + m_odd
|
||||
w = (w * 2 - w) * (1-m_odd) + w
|
||||
# map number to use sfix format and instantiate the number
|
||||
w = types.sfix(w * 2 ** ((x.f - (x.f % 2)) / 2))
|
||||
w = types.sfix(w * 2 ** ((x.f - (x.f % 2)) // 2))
|
||||
# obtains correct 2 ** (m/2)
|
||||
w = (w * (types.cfix(2 ** (1/2.0))) - w) * m_odd + w
|
||||
# produce x/ 2^(m/2)
|
||||
|
||||
139
Compiler/oram.py
139
Compiler/oram.py
@@ -4,6 +4,7 @@ import collections
|
||||
import itertools
|
||||
import operator
|
||||
import sys
|
||||
from functools import reduce
|
||||
|
||||
from Compiler.types import *
|
||||
from Compiler.types import _secret
|
||||
@@ -95,7 +96,7 @@ class gf2nBlock(Block):
|
||||
prod_bits = [start * bit for bit in value_bits]
|
||||
anti_bits = [v - p for v,p in zip(value_bits,prod_bits)]
|
||||
self.lower = sum(bit << i for i,bit in enumerate(prod_bits[:length]))
|
||||
self.bits = map(operator.add, anti_bits[:length], prod_bits[length:]) + \
|
||||
self.bits = list(map(operator.add, anti_bits[:length], prod_bits[length:])) + \
|
||||
anti_bits[length:]
|
||||
self.adjust = if_else(start, 1 << length, cgf2n(1))
|
||||
elif entries_per_block < 4:
|
||||
@@ -105,7 +106,7 @@ class gf2nBlock(Block):
|
||||
choice_bits = demux(start_bits)
|
||||
inv_bits = [1 - bit for bit in floatingpoint.PreOR(choice_bits, None)]
|
||||
mask_bits = sum(([x] * length for x in inv_bits), [])
|
||||
lower_bits = map(operator.mul, value_bits, mask_bits)
|
||||
lower_bits = list(map(operator.mul, value_bits, mask_bits))
|
||||
self.lower = sum(bit << i for i,bit in enumerate(lower_bits))
|
||||
self.bits = [sum(map(operator.mul, choice_bits, value_bits[i::length])) \
|
||||
for i in range(length)]
|
||||
@@ -124,7 +125,7 @@ class gf2nBlock(Block):
|
||||
pre_bits = floatingpoint.PreOpL(lambda x,y,z=None: x + y, bits)
|
||||
inv_bits = [1 - bit for bit in pre_bits]
|
||||
mask_bits = sum(([x] * length for x in inv_bits), [])
|
||||
lower_bits = map(operator.mul, value_bits, mask_bits)
|
||||
lower_bits = list(map(operator.mul, value_bits, mask_bits))
|
||||
masked = self.value - sum(bit << i for i,bit in enumerate(lower_bits))
|
||||
self.lower = sum(bit << i for i,bit in enumerate(lower_bits))
|
||||
self.bits = (masked / adjust).bit_decompose(used_bits)
|
||||
@@ -177,12 +178,12 @@ def demux_list(x):
|
||||
return [1]
|
||||
elif n == 1:
|
||||
return [1 - x[0], x[0]]
|
||||
a = demux_list(x[:n/2])
|
||||
b = demux_list(x[n/2:])
|
||||
a = demux_list(x[:n//2])
|
||||
b = demux_list(x[n//2:])
|
||||
n_a = len(a)
|
||||
a *= len(b)
|
||||
b = reduce(operator.add, ([i] * n_a for i in b))
|
||||
res = map(operator.mul, a, b)
|
||||
res = list(map(operator.mul, a, b))
|
||||
return res
|
||||
|
||||
def demux_array(x, res=None):
|
||||
@@ -193,12 +194,12 @@ def demux_array(x, res=None):
|
||||
res[0] = 1 - x[0]
|
||||
res[1] = x[0]
|
||||
else:
|
||||
a = Array(2**(n/2), type(x[0]))
|
||||
a.assign(demux(x[:n/2]))
|
||||
b = Array(2**(n-n/2), type(x[0]))
|
||||
b.assign(demux(x[n/2:]))
|
||||
a = Array(2**(n//2), type(x[0]))
|
||||
a.assign(demux(x[:n//2]))
|
||||
b = Array(2**(n-n//2), type(x[0]))
|
||||
b.assign(demux(x[n//2:]))
|
||||
@for_range_multithread(get_n_threads(len(res)), \
|
||||
max(1, n_parallel / len(b)), len(a))
|
||||
max(1, n_parallel // len(b)), len(a))
|
||||
def f(i):
|
||||
@for_range_parallel(n_parallel, len(b))
|
||||
def f(j):
|
||||
@@ -234,7 +235,7 @@ class Value(object):
|
||||
return Value(other * self.value, other * self.empty)
|
||||
__rmul__ = __mul__
|
||||
def equal(self, other, length=None):
|
||||
if isinstance(other, (int, long)) and isinstance(self.value, (int, long)):
|
||||
if isinstance(other, int) and isinstance(self.value, int):
|
||||
return (1 - self.empty) * (other == self.value)
|
||||
return (1 - self.empty) * self.value.equal(other, length)
|
||||
def reveal(self):
|
||||
@@ -252,9 +253,9 @@ class Value(object):
|
||||
try:
|
||||
value = self.empty
|
||||
while True:
|
||||
if value in (1, 1L):
|
||||
if value == 1:
|
||||
return '<>'
|
||||
if value in (0, 0L):
|
||||
if value == 0:
|
||||
return '<%s>' % str(self.value)
|
||||
value = value.value
|
||||
except:
|
||||
@@ -297,8 +298,8 @@ class Entry(object):
|
||||
self.created_non_empty = False
|
||||
if x is None:
|
||||
v = iter(v)
|
||||
self.is_empty = v.next()
|
||||
self.v = v.next()
|
||||
self.is_empty = next(v)
|
||||
self.v = next(v)
|
||||
self.x = ValueTuple(v)
|
||||
else:
|
||||
if empty is None:
|
||||
@@ -332,7 +333,7 @@ class Entry(object):
|
||||
try:
|
||||
return Entry(i + j for i,j in zip(self, other))
|
||||
except:
|
||||
print self, other
|
||||
print(self, other)
|
||||
raise
|
||||
def __sub__(self, other):
|
||||
return Entry(i - j for i,j in zip(self, other))
|
||||
@@ -342,7 +343,7 @@ class Entry(object):
|
||||
try:
|
||||
return Entry(other * i for i in self)
|
||||
except:
|
||||
print self, other
|
||||
print(self, other)
|
||||
raise
|
||||
__rmul__ = __mul__
|
||||
def reveal(self):
|
||||
@@ -372,8 +373,8 @@ class RefRAM(object):
|
||||
for t,array in zip(self.entry_type,oram.ram.l)]
|
||||
self.index = index
|
||||
def init_mem(self, empty_entry):
|
||||
print 'init ram'
|
||||
for a,value in zip(self.l, empty_entry.defaults.values()):
|
||||
print('init ram')
|
||||
for a,value in zip(self.l, list(empty_entry.defaults.values())):
|
||||
# don't use threads if n_threads explicitly set to 1
|
||||
a.assign_all(value, n_threads != 1, conv=False)
|
||||
def get_empty_bits(self):
|
||||
@@ -392,14 +393,14 @@ class RefRAM(object):
|
||||
return [Value(self.l[2+index][i], self.l[0][i]) for i in range(self.size)]
|
||||
def __getitem__(self, index):
|
||||
if print_access:
|
||||
print 'get', id(self), index
|
||||
print('get', id(self), index)
|
||||
return Entry(a[index] for a in self.l)
|
||||
def __setitem__(self, index, value):
|
||||
if print_access:
|
||||
print 'set', id(self), index
|
||||
print('set', id(self), index)
|
||||
if not isinstance(value, Entry):
|
||||
raise Exception('entries only please: %s' % str(value))
|
||||
for i,(a,v) in enumerate(zip(self.l, value.values())):
|
||||
for i,(a,v) in enumerate(zip(self.l, list(value.values()))):
|
||||
a[index] = v
|
||||
def __len__(self):
|
||||
return self.size
|
||||
@@ -524,7 +525,7 @@ class RefTrivialORAM(EndRecursiveEviction):
|
||||
self.value_type, self.entry_size = oram.internal_entry_size()
|
||||
self.size = oram.bucket_size
|
||||
def init_mem(self):
|
||||
print 'init trivial oram'
|
||||
print('init trivial oram')
|
||||
self.ram.init_mem(self.empty_entry(apply_type=False))
|
||||
def search(self, read_index):
|
||||
if use_binary_search and self.value_type == sgf2n:
|
||||
@@ -554,7 +555,7 @@ class RefTrivialORAM(EndRecursiveEviction):
|
||||
self.last_index = read_index
|
||||
found, empty = self.search(read_index)
|
||||
entries = [entry for entry in self.ram]
|
||||
prod_entries = map(operator.mul, found, entries)
|
||||
prod_entries = list(map(operator.mul, found, entries))
|
||||
read_value = sum((entry.x.skip(skip) for entry in prod_entries), \
|
||||
empty * empty_entry.x.skip(skip))
|
||||
for i,(entry, prod_entry) in enumerate(zip(entries, prod_entries)):
|
||||
@@ -566,7 +567,7 @@ class RefTrivialORAM(EndRecursiveEviction):
|
||||
def read_and_remove_by_public(self, index):
|
||||
empty_entry = self.empty_entry(False)
|
||||
entries = [entry for entry in self.ram]
|
||||
prod_entries = map(operator.mul, index, entries)
|
||||
prod_entries = list(map(operator.mul, index, entries))
|
||||
read_entry = reduce(operator.add, prod_entries)
|
||||
for i,(entry, prod_entry) in enumerate(zip(entries, prod_entries)):
|
||||
self.ram[i] = entry - prod_entry + index[i] * empty_entry
|
||||
@@ -574,7 +575,7 @@ class RefTrivialORAM(EndRecursiveEviction):
|
||||
@method_block
|
||||
def _read(self, index):
|
||||
found, empty = self.search(index)
|
||||
read_value = sum(map(operator.mul, found, self.ram.get_values()), \
|
||||
read_value = sum(list(map(operator.mul, found, self.ram.get_values())), \
|
||||
empty * self.empty_entry(False).x)
|
||||
return read_value, empty
|
||||
@method_block
|
||||
@@ -583,8 +584,8 @@ class RefTrivialORAM(EndRecursiveEviction):
|
||||
found, not_found = self.search(index)
|
||||
add_here = self.find_first_empty()
|
||||
entries = [entry for entry in self.ram]
|
||||
prod_values = map(operator.mul, found, \
|
||||
(entry.x for entry in entries))
|
||||
prod_values = list(map(operator.mul, found, \
|
||||
(entry.x for entry in entries)))
|
||||
read_value = sum(prod_values, not_found * empty_entry.x)
|
||||
new_value = ValueTuple(new_value) \
|
||||
if isinstance(new_value, (tuple, list)) \
|
||||
@@ -699,15 +700,15 @@ class RefTrivialORAM(EndRecursiveEviction):
|
||||
for k in range(2**(j)):
|
||||
t = k + 2**(j) - 1
|
||||
if k % 2 == 0:
|
||||
M += bit_prods[(t-1)/2] * mult_tree[t]
|
||||
M += bit_prods[(t-1)//2] * mult_tree[t]
|
||||
|
||||
b = 1 - M.equal(0, 40, expand)
|
||||
|
||||
for k in range(2**j):
|
||||
t = k + 2**j - 1
|
||||
if k % 2 == 0:
|
||||
v = bit_prods[(t-1)/2] * b
|
||||
bit_prods[t] = bit_prods[(t-1)/2] - v
|
||||
v = bit_prods[(t-1)//2] * b
|
||||
bit_prods[t] = bit_prods[(t-1)//2] - v
|
||||
else:
|
||||
bit_prods[t] = v
|
||||
return bit_prods[n-1:n-1+self.size], 1 - bit_prods[0]
|
||||
@@ -734,7 +735,7 @@ class RefTrivialORAM(EndRecursiveEviction):
|
||||
print_ln('Bucket overflow')
|
||||
crash()
|
||||
if debug and not sum(add_here) and not new_entry.empty():
|
||||
print self.empty_entry()
|
||||
print(self.empty_entry())
|
||||
raise Exception('no space for %s in %s' % (str(new_entry), str(self)))
|
||||
self.check(new_entry=new_entry, op='add')
|
||||
def pop(self):
|
||||
@@ -746,7 +747,7 @@ class RefTrivialORAM(EndRecursiveEviction):
|
||||
pop_here = [prefix_empty[i+1] - prefix_empty[i] \
|
||||
for i in range(len(self.ram))]
|
||||
entries = [entry for entry in self.ram]
|
||||
prod_entries = map(operator.mul, pop_here, self.ram)
|
||||
prod_entries = list(map(operator.mul, pop_here, self.ram))
|
||||
result = (1 - sum(pop_here)) * empty_entry
|
||||
result = sum(prod_entries, result)
|
||||
for i,(entry, prod_entry) in enumerate(zip(entries, prod_entries)):
|
||||
@@ -980,7 +981,7 @@ class LocalIndexStructure(List):
|
||||
@for_range(init_rounds if init_rounds > 0 else size)
|
||||
def f(i):
|
||||
self.l[0][i] = random_block(entry_size, value_type)
|
||||
print 'index size:', size
|
||||
print('index size:', size)
|
||||
def update(self, index, value, evict=None):
|
||||
read_value = self[index]
|
||||
#print 'read', index, read_value
|
||||
@@ -1005,7 +1006,7 @@ class TreeORAM(AbstractORAM):
|
||||
""" Tree ORAM. """
|
||||
def __init__(self, size, value_type=sint, value_length=1, entry_size=None, \
|
||||
bucket_oram=TrivialORAM, init_rounds=-1):
|
||||
print 'create oram of size', size
|
||||
print('create oram of size', size)
|
||||
self.bucket_oram = bucket_oram
|
||||
# heuristic bucket size
|
||||
delta = 3
|
||||
@@ -1013,9 +1014,9 @@ class TreeORAM(AbstractORAM):
|
||||
# size + 1 for bucket overflow check
|
||||
self.bucket_size = min(int(math.ceil((1 + delta) * k)), size + 1)
|
||||
self.D = log2(max(size / k, 2))
|
||||
print 'bucket size:', self.bucket_size
|
||||
print 'depth:', self.D
|
||||
print 'complexity:', self.bucket_size * (self.D + 1)
|
||||
print('bucket size:', self.bucket_size)
|
||||
print('depth:', self.D)
|
||||
print('complexity:', self.bucket_size * (self.D + 1))
|
||||
self.value_type = value_type
|
||||
if entry_size is not None:
|
||||
self.value_length = len(tuplify(entry_size))
|
||||
@@ -1279,8 +1280,8 @@ class TreeORAM(AbstractORAM):
|
||||
# split into 2 if bucket size can't fit into one field elem
|
||||
if self.bucket_size + Program.prog.security > 128:
|
||||
parity = (empty_positions[i]+1) % 2
|
||||
half = (empty_positions[i]+1 - parity) / 2
|
||||
half_max = self.bucket_size / 2
|
||||
half = (empty_positions[i]+1 - parity) // 2
|
||||
half_max = self.bucket_size // 2
|
||||
|
||||
bits = floatingpoint.B2U(half, half_max, Program.prog.security)[0]
|
||||
bits2 = floatingpoint.B2U(half+parity, half_max, Program.prog.security)[0]
|
||||
@@ -1384,11 +1385,11 @@ def get_parallel(index_size, value_type, value_length):
|
||||
value_size = get_value_size(value_type)
|
||||
if value_type == sint:
|
||||
value_size *= 2
|
||||
res = max(1, min(50 * 32 / (value_length * value_size), \
|
||||
800 * 32 / (value_length * index_size)))
|
||||
res = max(1, min(50 * 32 // (value_length * value_size), \
|
||||
800 * 32 // (value_length * index_size)))
|
||||
if comparison.const_rounds:
|
||||
res = max(1, res / 2)
|
||||
print 'Reading %d buckets in parallel' % res
|
||||
res = max(1, res // 2)
|
||||
print('Reading %d buckets in parallel' % res)
|
||||
return res
|
||||
|
||||
class PackedIndexStructure(object):
|
||||
@@ -1403,7 +1404,7 @@ class PackedIndexStructure(object):
|
||||
self.value_type = value_type
|
||||
for demux_bits in range(max_demux_bits + 1):
|
||||
self.log_entries_per_element = min(log2(size), \
|
||||
int(math.floor(math.log(float(get_value_size(value_type)) / \
|
||||
int(math.floor(math.log(float(get_value_size(value_type)) // \
|
||||
sum(self.entry_size), 2))))
|
||||
self.log_elements_per_block = \
|
||||
max(0, min(demux_bits, log2(size) - \
|
||||
@@ -1423,24 +1424,24 @@ class PackedIndexStructure(object):
|
||||
self.elements_per_entry = len(self.split_sizes)
|
||||
self.log_elements_per_block = log2(self.elements_per_entry)
|
||||
self.log_entries_per_element = -self.log_elements_per_block
|
||||
print 'split sizes:', self.split_sizes
|
||||
print('split sizes:', self.split_sizes)
|
||||
self.log_entries_per_block = \
|
||||
self.log_elements_per_block + self.log_entries_per_element
|
||||
self.elements_per_block = 2**self.log_elements_per_block
|
||||
self.entries_per_element = 2**self.log_entries_per_element
|
||||
self.entries_per_block = 2**self.log_entries_per_block
|
||||
self.used_bits = self.entries_per_element * sum(self.entry_size)
|
||||
real_size = -(-size / self.entries_per_block)
|
||||
print 'packed size:', real_size
|
||||
print 'index size:', size
|
||||
print 'entry size:', self.entry_size
|
||||
print 'log(entries per element):', self.log_entries_per_element
|
||||
print 'entries per element:', self.entries_per_element
|
||||
print 'log(entries per block):', self.log_entries_per_block
|
||||
print 'entries per block:', self.entries_per_block
|
||||
print 'log(elements per block):', self.log_elements_per_block
|
||||
print 'elements per block:', self.elements_per_block
|
||||
print 'used bits:', self.used_bits
|
||||
real_size = -(-size // self.entries_per_block)
|
||||
print('packed size:', real_size)
|
||||
print('index size:', size)
|
||||
print('entry size:', self.entry_size)
|
||||
print('log(entries per element):', self.log_entries_per_element)
|
||||
print('entries per element:', self.entries_per_element)
|
||||
print('log(entries per block):', self.log_entries_per_block)
|
||||
print('entries per block:', self.entries_per_block)
|
||||
print('log(elements per block):', self.log_elements_per_block)
|
||||
print('elements per block:', self.elements_per_block)
|
||||
print('used bits:', self.used_bits)
|
||||
entry_size = [self.used_bits] * self.elements_per_block
|
||||
if real_size > 1:
|
||||
# no need to init underlying ORAM, will be initialized implicitely
|
||||
@@ -1454,10 +1455,10 @@ class PackedIndexStructure(object):
|
||||
self.index_type = self.l.index_type
|
||||
if init_rounds:
|
||||
if init_rounds > 0:
|
||||
real_init_rounds = init_rounds * real_size / size
|
||||
real_init_rounds = init_rounds * real_size // size
|
||||
else:
|
||||
real_init_rounds = real_size
|
||||
print 'packed init rounds:', real_init_rounds
|
||||
print('packed init rounds:', real_init_rounds)
|
||||
@for_range(real_init_rounds)
|
||||
def f(i):
|
||||
if random_init:
|
||||
@@ -1467,7 +1468,7 @@ class PackedIndexStructure(object):
|
||||
self.l[i] = [0] * self.elements_per_block
|
||||
time()
|
||||
print_ln('packed ORAM init %s/%s', i, real_init_rounds)
|
||||
print 'index initialized, size', size
|
||||
print('index initialized, size', size)
|
||||
def translate_index(self, index):
|
||||
""" Bit slicing *index* according parameters. Output is tuple
|
||||
(storage address, index with storage cell, index within
|
||||
@@ -1501,16 +1502,16 @@ class PackedIndexStructure(object):
|
||||
self.block = block
|
||||
self.index_vector = \
|
||||
demux(bit_decompose(self.b, self.pack.log_elements_per_block))
|
||||
self.vector = map(operator.mul, self.index_vector, block)
|
||||
self.vector = list(map(operator.mul, self.index_vector, block))
|
||||
self.element = get_block(sum(self.vector), self.c, \
|
||||
self.pack.entry_size, \
|
||||
self.pack.entries_per_element)
|
||||
return tuple(self.element.get_slice())
|
||||
def write(self, value):
|
||||
self.element.set_slice(value)
|
||||
anti_vector = map(operator.sub, self.block, self.vector)
|
||||
anti_vector = list(map(operator.sub, self.block, self.vector))
|
||||
updated_vector = [self.element.value * i for i in self.index_vector]
|
||||
updated_block = map(operator.add, anti_vector, updated_vector)
|
||||
updated_block = list(map(operator.add, anti_vector, updated_vector))
|
||||
return updated_block
|
||||
class MultiSlicer(object):
|
||||
def __init__(self, pack, index):
|
||||
@@ -1685,7 +1686,7 @@ def test_oram(oram_type, N, value_type=sint, iterations=100):
|
||||
value_type = value_type.get_type(32)
|
||||
index_type = value_type.get_type(log2(N))
|
||||
start_grind()
|
||||
print 'initialized'
|
||||
print('initialized')
|
||||
print_ln('initialized')
|
||||
stop_timer()
|
||||
# synchronize
|
||||
@@ -1718,7 +1719,7 @@ def test_oram(oram_type, N, value_type=sint, iterations=100):
|
||||
def test_oram_access(oram_type, N, value_type=sint, index_size=None, iterations=100):
|
||||
oram = oram_type(N, value_type=value_type, entry_size=32, \
|
||||
init_rounds=0)
|
||||
print 'initialized'
|
||||
print('initialized')
|
||||
print_reg(cint(0), 'init')
|
||||
stop_timer()
|
||||
# synchronize
|
||||
@@ -1731,11 +1732,11 @@ def test_oram_access(oram_type, N, value_type=sint, index_size=None, iterations=
|
||||
def f(i):
|
||||
oram.access(value_type(i % N), value_type(0), value_type(True))
|
||||
oram.access(value_type(i % N), value_type(i % N), value_type(True))
|
||||
print 'first write'
|
||||
print('first write')
|
||||
time()
|
||||
x = oram.access(value_type(i % N), value_type(0), value_type(False))
|
||||
x[0][0].reveal().print_reg('writ')
|
||||
print 'first read'
|
||||
print('first read')
|
||||
# @for_range(iterations)
|
||||
# def f(i):
|
||||
# x = oram.access(value_type(i % N), value_type(0), value_type(False), \
|
||||
@@ -1747,7 +1748,7 @@ def test_oram_access(oram_type, N, value_type=sint, index_size=None, iterations=
|
||||
def test_batch_init(oram_type, N):
|
||||
value_type = sint
|
||||
oram = oram_type(N, value_type)
|
||||
print 'initialized'
|
||||
print('initialized')
|
||||
print_reg(cint(0), 'init')
|
||||
oram.batch_init([value_type(i) for i in range(N)])
|
||||
print_reg(cint(0), 'done')
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
if '_Array' not in dir():
|
||||
from oram import *
|
||||
import permutation
|
||||
from Compiler.oram import *
|
||||
from Compiler import permutation
|
||||
_Array = Array
|
||||
|
||||
import oram
|
||||
from Compiler import oram
|
||||
from functools import reduce
|
||||
|
||||
#import pdb
|
||||
|
||||
@@ -140,7 +141,7 @@ class PathORAM(TreeORAM):
|
||||
bucket_size=2, init_rounds=-1):
|
||||
#if size <= k:
|
||||
# raise CompilerError('ORAM size too small')
|
||||
print 'create oram of size', size
|
||||
print('create oram of size', size)
|
||||
self.bucket_oram = bucket_oram
|
||||
self.bucket_size = bucket_size
|
||||
self.D = log2(size)
|
||||
@@ -240,7 +241,7 @@ class PathORAM(TreeORAM):
|
||||
|
||||
self.state.write(self.value_type(leaf))
|
||||
|
||||
print 'eviction leaf =', leaf
|
||||
print('eviction leaf =', leaf)
|
||||
|
||||
# load the path
|
||||
for i, ram_indices in enumerate(self.bucket_indices_on_path_to(leaf)):
|
||||
@@ -325,7 +326,7 @@ class PathORAM(TreeORAM):
|
||||
|
||||
# at most one 1 in found
|
||||
empty = 1 - sum(found)
|
||||
prod_entries = map(operator.mul, found, entries)
|
||||
prod_entries = list(map(operator.mul, found, entries))
|
||||
read_value = sum((entry.x.skip(skip) for entry in prod_entries), \
|
||||
empty * empty_entry.x.skip(skip))
|
||||
for i,(j, entry, prod_entry) in enumerate(zip(ram_indices, entries, prod_entries)):
|
||||
@@ -528,7 +529,7 @@ class PathORAM(TreeORAM):
|
||||
values = (ValueTuple(x) for x in zip(*self.read_value))
|
||||
not_empty = [1 - x for x in self.read_empty]
|
||||
read_empty = 1 - sum(not_empty)
|
||||
read_value = sum(map(operator.mul, not_empty, values), \
|
||||
read_value = sum(list(map(operator.mul, not_empty, values)), \
|
||||
ValueTuple(0 for i in range(self.value_length)))
|
||||
self.check(u)
|
||||
Program.prog.curr_tape.\
|
||||
@@ -545,7 +546,7 @@ class PathORAM(TreeORAM):
|
||||
yield bucket
|
||||
def bucket_indices_on_path_to(self, leaf):
|
||||
leaf = regint(leaf)
|
||||
yield range(self.bucket_size)
|
||||
yield list(range(self.bucket_size))
|
||||
index = 0
|
||||
for i in range(self.D):
|
||||
index = 2*index + 1 + regint(cint(leaf) & 1)
|
||||
@@ -742,7 +743,7 @@ class PathORAM(TreeORAM):
|
||||
try:
|
||||
self.stash.add(e)
|
||||
except Exception:
|
||||
print self
|
||||
print(self)
|
||||
raise
|
||||
if evict:
|
||||
self.evict()
|
||||
|
||||
@@ -69,7 +69,7 @@ def odd_even_merge(a, comp):
|
||||
odd_even_merge(even, comp)
|
||||
odd_even_merge(odd, comp)
|
||||
a[0] = even[0]
|
||||
for i in range(1, len(a) / 2):
|
||||
for i in range(1, len(a) // 2):
|
||||
a[2*i-1], a[2*i] = cond_swap(odd[i-1], even[i], comp)
|
||||
a[-1] = odd[-1]
|
||||
|
||||
@@ -77,8 +77,8 @@ def odd_even_merge_sort(a, comp=bitwise_comparator):
|
||||
if len(a) == 1:
|
||||
return
|
||||
elif len(a) % 2 == 0:
|
||||
lower = a[:len(a)/2]
|
||||
upper = a[len(a)/2:]
|
||||
lower = a[:len(a)//2]
|
||||
upper = a[len(a)//2:]
|
||||
odd_even_merge_sort(lower, comp)
|
||||
odd_even_merge_sort(upper, comp)
|
||||
a[:] = lower + upper
|
||||
@@ -137,7 +137,7 @@ def random_perm(n):
|
||||
if not Program.prog.options.insecure:
|
||||
raise CompilerError('no secure implementation of Waksman permution, '
|
||||
'use --insecure to activate')
|
||||
a = range(n)
|
||||
a = list(range(n))
|
||||
for i in range(n-1, 0, -1):
|
||||
j = randint(0, i)
|
||||
t = a[i]
|
||||
@@ -155,10 +155,10 @@ def configure_waksman(perm):
|
||||
n = len(perm)
|
||||
if n == 2:
|
||||
return [(perm[0], perm[0])]
|
||||
I = [None] * (n/2)
|
||||
O = [None] * (n/2)
|
||||
p0 = [None] * (n/2)
|
||||
p1 = [None] * (n/2)
|
||||
I = [None] * (n//2)
|
||||
O = [None] * (n//2)
|
||||
p0 = [None] * (n//2)
|
||||
p1 = [None] * (n//2)
|
||||
inv_perm = [0] * n
|
||||
|
||||
for i, p in enumerate(perm):
|
||||
@@ -170,7 +170,7 @@ def configure_waksman(perm):
|
||||
except ValueError:
|
||||
break
|
||||
#print 'j =', j
|
||||
O[j/2] = 0
|
||||
O[j//2] = 0
|
||||
via = 0
|
||||
j0 = j
|
||||
while True:
|
||||
@@ -178,10 +178,10 @@ def configure_waksman(perm):
|
||||
|
||||
i = inv_perm[j]
|
||||
#print ' p0[%d] = %d' % (inv_perm[j]/2, j/2)
|
||||
p0[i/2] = j/2
|
||||
p0[i//2] = j//2
|
||||
|
||||
I[i/2] = i % 2
|
||||
O[j/2] = j % 2
|
||||
I[i//2] = i % 2
|
||||
O[j//2] = j % 2
|
||||
#print ' O[%d] = %d' % (j/2, j % 2)
|
||||
if i % 2 == 1:
|
||||
i -= 1
|
||||
@@ -198,7 +198,7 @@ def configure_waksman(perm):
|
||||
j += 1
|
||||
#j, via = set_swapper(O, i, via, perm)
|
||||
#print ' p1[%d] = %d' % (i/2, perm[i]/2)
|
||||
p1[i/2] = perm[i]/2
|
||||
p1[i//2] = perm[i]//2
|
||||
|
||||
#print ' i = %d, j = %d' %(i,j)
|
||||
if j == j0:
|
||||
@@ -206,8 +206,8 @@ def configure_waksman(perm):
|
||||
if None not in p0 and None not in p1:
|
||||
break
|
||||
|
||||
assert sorted(p0) == range(n/2)
|
||||
assert sorted(p1) == range(n/2)
|
||||
assert sorted(p0) == list(range(n//2))
|
||||
assert sorted(p1) == list(range(n//2))
|
||||
p0_config = configure_waksman(p0)
|
||||
p1_config = configure_waksman(p1)
|
||||
return [I + O] + [a+b for a,b in zip(p0_config, p1_config)]
|
||||
@@ -219,22 +219,22 @@ def waksman(a, config, depth=0, start=0, reverse=False):
|
||||
a[0], a[1] = cond_swap_bit(a[0], a[1], config[depth][start])
|
||||
return
|
||||
|
||||
a0 = [0] * (n/2)
|
||||
a1 = [0] * (n/2)
|
||||
for i in range(n/2):
|
||||
a0 = [0] * (n//2)
|
||||
a1 = [0] * (n//2)
|
||||
for i in range(n//2):
|
||||
if reverse:
|
||||
a0[i], a1[i] = cond_swap_bit(a[2*i], a[2*i+1], config[depth][i + n/2 + start])
|
||||
a0[i], a1[i] = cond_swap_bit(a[2*i], a[2*i+1], config[depth][i + n//2 + start])
|
||||
else:
|
||||
a0[i], a1[i] = cond_swap_bit(a[2*i], a[2*i+1], config[depth][i + start])
|
||||
|
||||
waksman(a0, config, depth+1, start, reverse)
|
||||
waksman(a1, config, depth+1, start + n/2, reverse)
|
||||
waksman(a1, config, depth+1, start + n//2, reverse)
|
||||
|
||||
for i in range(n/2):
|
||||
for i in range(n//2):
|
||||
if reverse:
|
||||
a[2*i], a[2*i+1] = cond_swap_bit(a0[i], a1[i], config[depth][i + start])
|
||||
else:
|
||||
a[2*i], a[2*i+1] = cond_swap_bit(a0[i], a1[i], config[depth][i + n/2 + start])
|
||||
a[2*i], a[2*i+1] = cond_swap_bit(a0[i], a1[i], config[depth][i + n//2 + start])
|
||||
|
||||
|
||||
WAKSMAN_FUNCTIONS = {}
|
||||
@@ -263,11 +263,11 @@ def iter_waksman(a, config, reverse=False):
|
||||
outwards = 1 - inwards
|
||||
|
||||
sizeval = size
|
||||
#for k in range(n/2):
|
||||
@for_range_parallel(200, n/2)
|
||||
#for k in range(n//2):
|
||||
@for_range_parallel(200, n//2)
|
||||
def f(k):
|
||||
j = cint(k) % sizeval
|
||||
i = (cint(k) - j)/sizeval
|
||||
i = (cint(k) - j)//sizeval
|
||||
base = 2*i*sizeval
|
||||
|
||||
in1, in2 = (base+j+j*inwards), (base+j+j*inwards+1*inwards+sizeval*outwards)
|
||||
@@ -297,7 +297,7 @@ def iter_waksman(a, config, reverse=False):
|
||||
# going into middle of network
|
||||
@for_range(logn)
|
||||
def f(i):
|
||||
size.write(n/(2*nblocks))
|
||||
size.write(n//(2*nblocks))
|
||||
conf_address = MemValue(config.address + depth.read()*n)
|
||||
do_round(size, conf_address, a.address, a2.address, 1)
|
||||
|
||||
@@ -307,20 +307,20 @@ def iter_waksman(a, config, reverse=False):
|
||||
nblocks.write(nblocks*2)
|
||||
depth.write(depth+1)
|
||||
|
||||
nblocks.write(nblocks/4)
|
||||
nblocks.write(nblocks//4)
|
||||
depth.write(depth-2)
|
||||
|
||||
# and back out
|
||||
@for_range(logn-1)
|
||||
def f(i):
|
||||
size.write(n/(2*nblocks))
|
||||
size.write(n//(2*nblocks))
|
||||
conf_address = MemValue(config.address + depth.read()*n)
|
||||
do_round(size, conf_address, a.address, a2.address, 0)
|
||||
|
||||
for i in range(n):
|
||||
a[i] = a2[i]
|
||||
|
||||
nblocks.write(nblocks/2)
|
||||
nblocks.write(nblocks//2)
|
||||
depth.write(depth-1)
|
||||
|
||||
## going into middle of network
|
||||
@@ -375,7 +375,7 @@ def config_shuffle(n, value_type):
|
||||
if n & (n-1) != 0:
|
||||
# pad permutation to power of 2
|
||||
m = 2**int(math.ceil(math.log(n, 2)))
|
||||
perm += range(n, m)
|
||||
perm += list(range(n, m))
|
||||
config_bits = configure_waksman(perm)
|
||||
# 2-D array
|
||||
config = Array(len(config_bits) * len(perm), value_type.reg_type)
|
||||
|
||||
@@ -3,15 +3,17 @@ from Compiler.exceptions import *
|
||||
from Compiler.instructions_base import RegType
|
||||
import Compiler.instructions
|
||||
import Compiler.instructions_base
|
||||
import compilerLib
|
||||
import allocator as al
|
||||
from . import compilerLib
|
||||
from . import allocator as al
|
||||
from . import util
|
||||
import random
|
||||
import time
|
||||
import sys, os, errno
|
||||
import inspect
|
||||
from collections import defaultdict
|
||||
from collections import defaultdict, deque
|
||||
import itertools
|
||||
import math
|
||||
from functools import reduce
|
||||
|
||||
|
||||
data_types = dict(
|
||||
@@ -50,11 +52,11 @@ class Program(object):
|
||||
self.bit_length = int(options.binary) or int(options.field)
|
||||
if not self.bit_length:
|
||||
self.bit_length = BIT_LENGTHS[param]
|
||||
print 'Default bit length:', self.bit_length
|
||||
print('Default bit length:', self.bit_length)
|
||||
self.security = 40
|
||||
print 'Default security parameter:', self.security
|
||||
print('Default security parameter:', self.security)
|
||||
self.galois_length = int(options.galois)
|
||||
print 'Galois length:', self.galois_length
|
||||
print('Galois length:', self.galois_length)
|
||||
self.schedule = [('start', [])]
|
||||
self.tape_counter = 0
|
||||
self.tapes = []
|
||||
@@ -118,7 +120,7 @@ class Program(object):
|
||||
running[tape] -= 1
|
||||
else:
|
||||
raise CompilerError('Invalid schedule action')
|
||||
res = max(res, sum(running.itervalues()))
|
||||
res = max(res, sum(running.values()))
|
||||
return res
|
||||
|
||||
def init_names(self, args, assemblymode):
|
||||
@@ -129,7 +131,7 @@ class Program(object):
|
||||
else:
|
||||
# assume source is in main SPDZ directory
|
||||
self.programs_dir = sys.path[0] + '/Programs'
|
||||
print 'Compiling program in', self.programs_dir
|
||||
print('Compiling program in', self.programs_dir)
|
||||
|
||||
# create extra directories if needed
|
||||
for dirname in ['Public-Input', 'Bytecode', 'Schedules']:
|
||||
@@ -225,7 +227,7 @@ class Program(object):
|
||||
def read_memory(self, filename):
|
||||
""" Read the clear and shared memory from a file """
|
||||
f = open(filename)
|
||||
n = int(f.next())
|
||||
n = int(next(f))
|
||||
self.mem_c = [0]*n
|
||||
self.mem_s = [0]*n
|
||||
mem = self.mem_c
|
||||
@@ -253,8 +255,8 @@ class Program(object):
|
||||
""" Reset register and memory values. """
|
||||
for tape in self.tapes:
|
||||
tape.reset_registers()
|
||||
self.mem_c = range(USER_MEM + TMP_MEM)
|
||||
self.mem_s = range(USER_MEM + TMP_MEM)
|
||||
self.mem_c = list(range(USER_MEM + TMP_MEM))
|
||||
self.mem_s = list(range(USER_MEM + TMP_MEM))
|
||||
|
||||
def write_bytes(self, outfile=None):
|
||||
""" Write all non-empty threads and schedule to files. """
|
||||
@@ -265,7 +267,7 @@ class Program(object):
|
||||
|
||||
sch_filename = self.programs_dir + '/Schedules/%s.sch' % self.name
|
||||
sch_file = open(sch_filename, 'w')
|
||||
print 'Writing to', sch_filename
|
||||
print('Writing to', sch_filename)
|
||||
sch_file.write(str(self.max_par_tapes()) + '\n')
|
||||
sch_file.write(str(len(nonempty_tapes)) + '\n')
|
||||
sch_file.write(' '.join(tape.name for tape in nonempty_tapes) + '\n')
|
||||
@@ -276,7 +278,7 @@ class Program(object):
|
||||
|
||||
for sch in self.schedule:
|
||||
# schedule may still contain empty tapes: ignore these
|
||||
tapes = filter(lambda x: not x[0].is_empty(), sch[1])
|
||||
tapes = [x for x in sch[1] if not x[0].is_empty()]
|
||||
# no empty line
|
||||
if not tapes:
|
||||
continue
|
||||
@@ -358,7 +360,7 @@ class Program(object):
|
||||
|
||||
def malloc(self, size, mem_type, reg_type=None):
|
||||
""" Allocate memory from the top """
|
||||
if not isinstance(size, (int, long)):
|
||||
if not isinstance(size, int):
|
||||
raise CompilerError('size must be known at compile time')
|
||||
if size == 0:
|
||||
return
|
||||
@@ -374,7 +376,7 @@ class Program(object):
|
||||
addr = self.allocated_mem[mem_type]
|
||||
self.allocated_mem[mem_type] += size
|
||||
if len(str(addr)) != len(str(addr + size)):
|
||||
print "Memory of type '%s' now of size %d" % (mem_type, addr + size)
|
||||
print("Memory of type '%s' now of size %d" % (mem_type, addr + size))
|
||||
self.allocated_mem_blocks[addr,mem_type] = size
|
||||
return addr
|
||||
|
||||
@@ -387,11 +389,11 @@ class Program(object):
|
||||
self.free_mem_blocks[size,mem_type].add(addr)
|
||||
|
||||
def finalize_memory(self):
|
||||
import library
|
||||
from . import library
|
||||
self.curr_tape.start_new_basicblock(None, 'memory-usage')
|
||||
# reset register counter to 0
|
||||
self.curr_tape.init_registers()
|
||||
for mem_type,size in self.allocated_mem.items():
|
||||
for mem_type,size in list(self.allocated_mem.items()):
|
||||
if size:
|
||||
#print "Memory of type '%s' of size %d" % (mem_type, size)
|
||||
if mem_type in self.types:
|
||||
@@ -404,11 +406,11 @@ class Program(object):
|
||||
|
||||
def set_bit_length(self, bit_length):
|
||||
self.bit_length = bit_length
|
||||
print 'Changed bit length for comparisons etc. to', bit_length
|
||||
print('Changed bit length for comparisons etc. to', bit_length)
|
||||
|
||||
def set_security(self, security):
|
||||
self.security = security
|
||||
print 'Changed statistical security for comparison etc. to', security
|
||||
print('Changed statistical security for comparison etc. to', security)
|
||||
|
||||
def optimize_for_gc(self):
|
||||
pass
|
||||
@@ -500,9 +502,9 @@ class Tape:
|
||||
#print 'Basic block %d jumps to %d (%d)' % (next_block_index, jump_index, offset)
|
||||
|
||||
def purge(self):
|
||||
relevant = lambda inst: inst.add_usage.__func__ is not \
|
||||
Compiler.instructions_base.Instruction.add_usage.__func__
|
||||
self.usage_instructions = filter(relevant, self.instructions)
|
||||
relevant = lambda inst: inst.add_usage is not \
|
||||
Compiler.instructions_base.Instruction.add_usage
|
||||
self.usage_instructions = list(filter(relevant, self.instructions))
|
||||
del self.instructions
|
||||
del self.defined_registers
|
||||
self.purged = True
|
||||
@@ -568,8 +570,8 @@ class Tape:
|
||||
def unpurged(function):
|
||||
def wrapper(self, *args, **kwargs):
|
||||
if self.purged:
|
||||
print '%s called on purged block %s, ignoring' % \
|
||||
(function.__name__, self.name)
|
||||
print('%s called on purged block %s, ignoring' % \
|
||||
(function.__name__, self.name))
|
||||
return
|
||||
return function(self, *args, **kwargs)
|
||||
return wrapper
|
||||
@@ -577,13 +579,13 @@ class Tape:
|
||||
@unpurged
|
||||
def optimize(self, options):
|
||||
if len(self.basicblocks) == 0:
|
||||
print 'Tape %s is empty' % self.name
|
||||
print('Tape %s is empty' % self.name)
|
||||
return
|
||||
|
||||
if self.if_states:
|
||||
raise CompilerError('Unclosed if/else blocks')
|
||||
|
||||
print 'Processing tape', self.name, 'with %d blocks' % len(self.basicblocks)
|
||||
print('Processing tape', self.name, 'with %d blocks' % len(self.basicblocks))
|
||||
|
||||
for block in self.basicblocks:
|
||||
al.determine_scope(block, options)
|
||||
@@ -593,38 +595,38 @@ class Tape:
|
||||
if (options.merge_opens and self.merge_opens) or options.dead_code_elimination:
|
||||
for i,block in enumerate(self.basicblocks):
|
||||
if len(block.instructions) > 0:
|
||||
print 'Processing basic block %s, %d/%d, %d instructions' % \
|
||||
print('Processing basic block %s, %d/%d, %d instructions' % \
|
||||
(block.name, i, len(self.basicblocks), \
|
||||
len(block.instructions))
|
||||
len(block.instructions)))
|
||||
# the next call is necessary for allocation later even without merging
|
||||
merger = al.Merger(block, options, \
|
||||
tuple(self.program.to_merge))
|
||||
if options.dead_code_elimination:
|
||||
if len(block.instructions) > 10000:
|
||||
print 'Eliminate dead code...'
|
||||
print('Eliminate dead code...')
|
||||
merger.eliminate_dead_code()
|
||||
if options.merge_opens and self.merge_opens:
|
||||
if len(block.instructions) == 0:
|
||||
block.used_from_scope = set()
|
||||
block.defined_registers = set()
|
||||
block.used_from_scope = util.set_by_id()
|
||||
block.defined_registers = util.set_by_id()
|
||||
continue
|
||||
if len(block.instructions) > 10000:
|
||||
print 'Merging instructions...'
|
||||
print('Merging instructions...')
|
||||
numrounds = merger.longest_paths_merge()
|
||||
block.n_rounds = numrounds
|
||||
block.n_to_merge = len(merger.open_nodes)
|
||||
if numrounds > 0:
|
||||
print 'Program requires %d rounds of communication' % numrounds
|
||||
print('Program requires %d rounds of communication' % numrounds)
|
||||
if merger.counter:
|
||||
print 'Block requires', \
|
||||
print('Block requires', \
|
||||
', '.join('%d %s' % (y, x.__name__) \
|
||||
for x, y in merger.counter.items())
|
||||
for x, y in list(merger.counter.items())))
|
||||
# free memory
|
||||
merger = None
|
||||
if options.dead_code_elimination:
|
||||
block.instructions = filter(lambda x: x is not None, block.instructions)
|
||||
block.instructions = [x for x in block.instructions if x is not None]
|
||||
if not (options.merge_opens and self.merge_opens):
|
||||
print 'Not merging instructions in tape %s' % self.name
|
||||
print('Not merging instructions in tape %s' % self.name)
|
||||
|
||||
# add jumps
|
||||
offset = 0
|
||||
@@ -640,39 +642,44 @@ class Tape:
|
||||
block.adjust_return()
|
||||
|
||||
# now remove any empty blocks (must be done after setting jumps)
|
||||
self.basicblocks = filter(lambda x: len(x.instructions) != 0, self.basicblocks)
|
||||
self.basicblocks = [x for x in self.basicblocks if len(x.instructions) != 0]
|
||||
|
||||
# allocate registers
|
||||
reg_counts = self.count_regs()
|
||||
if not options.noreallocate:
|
||||
if self.program.verbose:
|
||||
print 'Tape register usage:', dict(reg_counts)
|
||||
print 'modp: %d clear, %d secret' % (reg_counts[RegType.ClearModp], reg_counts[RegType.SecretModp])
|
||||
print 'GF2N: %d clear, %d secret' % (reg_counts[RegType.ClearGF2N], reg_counts[RegType.SecretGF2N])
|
||||
print 'Re-allocating...'
|
||||
print('Tape register usage:', dict(reg_counts))
|
||||
print('modp: %d clear, %d secret' % (reg_counts[RegType.ClearModp], reg_counts[RegType.SecretModp]))
|
||||
print('GF2N: %d clear, %d secret' % (reg_counts[RegType.ClearGF2N], reg_counts[RegType.SecretGF2N]))
|
||||
print('Re-allocating...')
|
||||
allocator = al.StraightlineAllocator(REG_MAX)
|
||||
def alloc_loop(block):
|
||||
def alloc(block):
|
||||
for reg in sorted(block.used_from_scope,
|
||||
key=lambda x: (x.reg_type, x.i)):
|
||||
allocator.alloc_reg(reg, block.alloc_pool)
|
||||
for child in block.children:
|
||||
if child.instructions:
|
||||
alloc_loop(child)
|
||||
def alloc_loop(block):
|
||||
left = deque([block])
|
||||
while left:
|
||||
block = left.popleft()
|
||||
alloc(block)
|
||||
for child in block.children:
|
||||
if child.instructions:
|
||||
left.append(child)
|
||||
for i,block in enumerate(reversed(self.basicblocks)):
|
||||
if len(block.instructions) > 10000:
|
||||
print 'Allocating %s, %d/%d' % \
|
||||
(block.name, i, len(self.basicblocks))
|
||||
print('Allocating %s, %d/%d' % \
|
||||
(block.name, i, len(self.basicblocks)))
|
||||
if block.exit_condition is not None:
|
||||
jump = block.exit_condition.get_relative_jump()
|
||||
if isinstance(jump, (int,long)) and jump < 0 and \
|
||||
if isinstance(jump, int) and jump < 0 and \
|
||||
block.exit_block.scope is not None:
|
||||
alloc_loop(block.exit_block.scope)
|
||||
allocator.process(block.instructions, block.alloc_pool)
|
||||
|
||||
# offline data requirements
|
||||
print 'Compile offline data requirements...'
|
||||
print('Compile offline data requirements...')
|
||||
self.req_num = self.req_tree.aggregate()
|
||||
print 'Tape requires', self.req_num
|
||||
print('Tape requires', self.req_num)
|
||||
for req,num in sorted(self.req_num.items()):
|
||||
if num == float('inf') or num >= 2 ** 32:
|
||||
num = -1
|
||||
@@ -706,8 +713,8 @@ class Tape:
|
||||
Compiler.instructions.reqbl(bl,
|
||||
add_to_prog=False))
|
||||
if self.program.verbose:
|
||||
print 'Tape requires prime bit length', self.req_bit_length['p']
|
||||
print 'Tape requires galois bit length', self.req_bit_length['2']
|
||||
print('Tape requires prime bit length', self.req_bit_length['p'])
|
||||
print('Tape requires galois bit length', self.req_bit_length['2'])
|
||||
|
||||
@unpurged
|
||||
def _get_instructions(self):
|
||||
@@ -722,12 +729,12 @@ class Tape:
|
||||
@unpurged
|
||||
def get_bytes(self):
|
||||
""" Get the byte encoding of the program as an actual string of bytes. """
|
||||
return "".join(str(i.get_bytes()) for i in self._get_instructions() if i is not None)
|
||||
return b"".join(i.get_bytes() for i in self._get_instructions() if i is not None)
|
||||
|
||||
@unpurged
|
||||
def write_encoding(self, filename):
|
||||
""" Write the readable encoding to a file. """
|
||||
print 'Writing to', filename
|
||||
print('Writing to', filename)
|
||||
f = open(filename, 'w')
|
||||
for line in self.get_encoding():
|
||||
f.write(str(line) + '\n')
|
||||
@@ -736,7 +743,7 @@ class Tape:
|
||||
@unpurged
|
||||
def write_str(self, filename):
|
||||
""" Write the sequence of instructions to a file. """
|
||||
print 'Writing to', filename
|
||||
print('Writing to', filename)
|
||||
f = open(filename, 'w')
|
||||
n = 0
|
||||
for block in self.basicblocks:
|
||||
@@ -756,8 +763,8 @@ class Tape:
|
||||
filename += '.bc'
|
||||
if not 'Bytecode' in filename:
|
||||
filename = self.program.programs_dir + '/Bytecode/' + filename
|
||||
print 'Writing to', filename
|
||||
f = open(filename, 'w')
|
||||
print('Writing to', filename)
|
||||
f = open(filename, 'wb')
|
||||
f.write(self.get_bytes())
|
||||
f.close()
|
||||
|
||||
@@ -785,9 +792,9 @@ class Tape:
|
||||
super(Tape.ReqNum, self).__init__(lambda: 0, init)
|
||||
def __add__(self, other):
|
||||
res = Tape.ReqNum()
|
||||
for i,count in self.items():
|
||||
for i,count in list(self.items()):
|
||||
res[i] += count
|
||||
for i,count in other.items():
|
||||
for i,count in list(other.items()):
|
||||
res[i] += count
|
||||
return res
|
||||
def __mul__(self, other):
|
||||
@@ -798,7 +805,7 @@ class Tape:
|
||||
__rmul__ = __mul__
|
||||
def set_all(self, value):
|
||||
if value == float('inf') and self['all', 'inv'] > 0:
|
||||
print 'Going to unknown from %s' % self
|
||||
print('Going to unknown from %s' % self)
|
||||
res = Tape.ReqNum()
|
||||
for i in self:
|
||||
res[i] = value
|
||||
@@ -811,14 +818,14 @@ class Tape:
|
||||
res[i] = max(self[i], other[i])
|
||||
return res
|
||||
def cost(self):
|
||||
return sum(num * COST[req[0]][req[1]] for req,num in self.items() \
|
||||
return sum(num * COST[req[0]][req[1]] for req,num in list(self.items()) \
|
||||
if req[1] != 'input')
|
||||
def __str__(self):
|
||||
return ", ".join('%s inputs in %s from player %d' \
|
||||
% (num, req[0], req[2]) \
|
||||
if req[1] == 'input' \
|
||||
else '%s %ss in %s' % (num, req[1], req[0]) \
|
||||
for req,num in self.items())
|
||||
for req,num in list(self.items()))
|
||||
def __repr__(self):
|
||||
return repr(dict(self))
|
||||
|
||||
@@ -853,8 +860,8 @@ class Tape:
|
||||
n_rounds = res['all', 'round']
|
||||
n_invs = res['all', 'inv']
|
||||
if (n_invs / n_rounds) * 1000 < n_reps:
|
||||
print self.nodes[0].blocks[0].name, 'blowing up rounds: ', \
|
||||
'(%d / %d) ** 3 < %d' % (n_rounds, n_reps, n_invs)
|
||||
print(self.nodes[0].blocks[0].name, 'blowing up rounds: ', \
|
||||
'(%d / %d) ** 3 < %d' % (n_rounds, n_reps, n_invs))
|
||||
except:
|
||||
pass
|
||||
return res
|
||||
@@ -892,7 +899,7 @@ class Tape:
|
||||
|
||||
The 'value' property is for emulation.
|
||||
"""
|
||||
__slots__ = ["reg_type", "program", "i", "value", "_is_active", \
|
||||
__slots__ = ["reg_type", "program", "i", "_is_active", \
|
||||
"size", "vector", "vectorbase", "caller", \
|
||||
"can_eliminate"]
|
||||
|
||||
@@ -925,7 +932,7 @@ class Tape:
|
||||
else:
|
||||
self.caller = None
|
||||
if self.i % 1000000 == 0 and self.i > 0:
|
||||
print "Initialized %d registers at" % self.i, time.asctime()
|
||||
print("Initialized %d registers at" % self.i, time.asctime())
|
||||
|
||||
def set_size(self, size):
|
||||
if self.size == size:
|
||||
|
||||
@@ -2,11 +2,12 @@ from Compiler.program import Tape
|
||||
from Compiler.exceptions import *
|
||||
from Compiler.instructions import *
|
||||
from Compiler.instructions_base import *
|
||||
from floatingpoint import two_power
|
||||
import comparison, floatingpoint
|
||||
from .floatingpoint import two_power
|
||||
from . import comparison, floatingpoint
|
||||
import math
|
||||
import util
|
||||
from . import util
|
||||
import operator
|
||||
from functools import reduce
|
||||
|
||||
|
||||
class ClientMessageType:
|
||||
@@ -127,15 +128,15 @@ class _number(object):
|
||||
return self * self
|
||||
|
||||
def __add__(self, other):
|
||||
if other is 0 or other is 0L:
|
||||
if other is 0:
|
||||
return self
|
||||
else:
|
||||
return self.add(other)
|
||||
|
||||
def __mul__(self, other):
|
||||
if other is 0 or other is 0L:
|
||||
if other is 0:
|
||||
return 0
|
||||
elif other is 1 or other is 1L:
|
||||
elif other is 1:
|
||||
return self
|
||||
else:
|
||||
return self.mul(other)
|
||||
@@ -301,7 +302,7 @@ class _register(Tape.Register, _number, _structure):
|
||||
if isinstance(val, (tuple, list)):
|
||||
size = len(val)
|
||||
super(_register, self).__init__(reg_type, program.curr_tape, size=size)
|
||||
if isinstance(val, (int, long)):
|
||||
if isinstance(val, int):
|
||||
self.load_int(val)
|
||||
elif isinstance(val, (tuple, list)):
|
||||
for i, x in enumerate(val):
|
||||
@@ -374,7 +375,7 @@ class _clear(_register):
|
||||
res = self.prep_res(other)
|
||||
if isinstance(other, cls):
|
||||
c_inst(res, self, other)
|
||||
elif isinstance(other, (int, long)):
|
||||
elif isinstance(other, int):
|
||||
if self.in_immediate_range(other):
|
||||
ci_inst(res, self, other)
|
||||
else:
|
||||
@@ -392,7 +393,7 @@ class _clear(_register):
|
||||
def coerce_op(self, other, inst, reverse=False):
|
||||
cls = self.__class__
|
||||
res = cls()
|
||||
if isinstance(other, (int, long)):
|
||||
if isinstance(other, int):
|
||||
other = cls(other)
|
||||
elif not isinstance(other, cls):
|
||||
return NotImplemented
|
||||
@@ -414,14 +415,14 @@ class _clear(_register):
|
||||
def __rsub__(self, other):
|
||||
return self.clear_op(other, subc, subcfi, True)
|
||||
|
||||
def __div__(self, other):
|
||||
def __truediv__(self, other):
|
||||
return self.clear_op(other, divc, divci)
|
||||
|
||||
def __rdiv__(self, other):
|
||||
def __rtruediv__(self, other):
|
||||
return self.coerce_op(other, divc, True)
|
||||
|
||||
def __eq__(self, other):
|
||||
if isinstance(other, (_clear,int,long)):
|
||||
if isinstance(other, (_clear,int)):
|
||||
return regint(self) == other
|
||||
else:
|
||||
return NotImplemented
|
||||
@@ -493,12 +494,12 @@ class cint(_clear, _int):
|
||||
ldi(self, val)
|
||||
else:
|
||||
max = 2**31 - 1
|
||||
sign = abs(val) / val
|
||||
sign = abs(val) // val
|
||||
val = abs(val)
|
||||
chunks = []
|
||||
while val:
|
||||
mod = val % max
|
||||
val = (val - mod) / max
|
||||
val = (val - mod) // max
|
||||
chunks.append(mod)
|
||||
sum = cint(sign * chunks.pop())
|
||||
for i,chunk in enumerate(reversed(chunks)):
|
||||
@@ -520,13 +521,13 @@ class cint(_clear, _int):
|
||||
return self.coerce_op(other, modc, True)
|
||||
|
||||
def __lt__(self, other):
|
||||
if isinstance(other, (type(self),int,long)):
|
||||
if isinstance(other, (type(self),int)):
|
||||
return regint(self) < other
|
||||
else:
|
||||
return NotImplemented
|
||||
|
||||
def __gt__(self, other):
|
||||
if isinstance(other, (type(self),int,long)):
|
||||
if isinstance(other, (type(self),int)):
|
||||
return regint(self) > other
|
||||
else:
|
||||
return NotImplemented
|
||||
@@ -537,6 +538,23 @@ class cint(_clear, _int):
|
||||
def __ge__(self, other):
|
||||
return 1 - (self < other)
|
||||
|
||||
@vectorize
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, (_clear, int)):
|
||||
return NotImplemented
|
||||
res = 1
|
||||
remaining = program.bit_length
|
||||
while remaining > 0:
|
||||
if isinstance(other, cint):
|
||||
o = other.to_regint(min(remaining, 64))
|
||||
else:
|
||||
o = other % 2 ** 64
|
||||
res *= (self.to_regint(min(remaining, 64)) == o)
|
||||
self >>= 64
|
||||
other >>= 64
|
||||
remaining -= 64
|
||||
return res
|
||||
|
||||
def __lshift__(self, other):
|
||||
return self.clear_op(other, shlc, shlci)
|
||||
|
||||
@@ -683,7 +701,7 @@ class cgf2n(_clear, _gf2n):
|
||||
def bit_decompose(self, bit_length=None, step=None):
|
||||
bit_length = bit_length or program.galois_length
|
||||
step = step or 1
|
||||
res = [type(self)() for _ in range(bit_length / step)]
|
||||
res = [type(self)() for _ in range(bit_length // step)]
|
||||
gbitdec(self, step, *res)
|
||||
return res
|
||||
|
||||
@@ -817,12 +835,15 @@ class regint(_register, _int):
|
||||
def __neg__(self):
|
||||
return 0 - self
|
||||
|
||||
def __div__(self, other):
|
||||
def __floordiv__(self, other):
|
||||
return self.int_op(other, divint)
|
||||
|
||||
def __rdiv__(self, other):
|
||||
def __rfloordiv__(self, other):
|
||||
return self.int_op(other, divint, True)
|
||||
|
||||
__truediv__ = __floordiv__
|
||||
__rtruediv__ = __rfloordiv__
|
||||
|
||||
def __mod__(self, other):
|
||||
return self - (self / other) * other
|
||||
|
||||
@@ -851,13 +872,13 @@ class regint(_register, _int):
|
||||
return 1 - (self < other)
|
||||
|
||||
def __lshift__(self, other):
|
||||
if isinstance(other, (int, long)):
|
||||
if isinstance(other, int):
|
||||
return self * 2**other
|
||||
else:
|
||||
return regint(cint(self) << other)
|
||||
|
||||
def __rshift__(self, other):
|
||||
if isinstance(other, (int, long)):
|
||||
if isinstance(other, int):
|
||||
return self / 2**other
|
||||
else:
|
||||
return regint(cint(self) >> other)
|
||||
@@ -911,6 +932,24 @@ class regint(_register, _int):
|
||||
def print_if(self, string):
|
||||
cint(self).print_if(string)
|
||||
|
||||
class localint(object):
|
||||
""" Local integer that must prevented from leaking into the secure
|
||||
computation. Uses regint internally. """
|
||||
|
||||
def __init__(self, value=None):
|
||||
self._v = regint(value)
|
||||
self.size = 1
|
||||
|
||||
def output(self):
|
||||
self._v.print_reg_plain()
|
||||
|
||||
__lt__ = lambda self, other: localint(self._v < other)
|
||||
__le__ = lambda self, other: localint(self._v <= other)
|
||||
__gt__ = lambda self, other: localint(self._v > other)
|
||||
__ge__ = lambda self, other: localint(self._v >= other)
|
||||
__eq__ = lambda self, other: localint(self._v == other)
|
||||
__ne__ = lambda self, other: localint(self._v != other)
|
||||
|
||||
class _secret(_register):
|
||||
__slots__ = []
|
||||
|
||||
@@ -996,11 +1035,11 @@ class _secret(_register):
|
||||
def matrix_mul(cls, A, B, n, res_params=None):
|
||||
assert len(A) % n == 0
|
||||
assert len(B) % n == 0
|
||||
size = len(A) * len(B) / n**2
|
||||
size = len(A) * len(B) // n**2
|
||||
res = cls(size=size)
|
||||
n_rows = len(A) / n
|
||||
n_cols = len(B) / n
|
||||
dotprods(*sum(([res[j], [A[j / n_cols * n + k] for k in range(n)],
|
||||
n_rows = len(A) // n
|
||||
n_cols = len(B) // n
|
||||
dotprods(*sum(([res[j], [A[j // n_cols * n + k] for k in range(n)],
|
||||
[B[k * n_cols + j % n_cols] for k in range(n)]]
|
||||
for j in range(size)), []))
|
||||
return res
|
||||
@@ -1054,7 +1093,7 @@ class _secret(_register):
|
||||
m_inst(res, other, self)
|
||||
else:
|
||||
m_inst(res, self, other)
|
||||
elif isinstance(other, (int, long)):
|
||||
elif isinstance(other, int):
|
||||
if self.clear_type.in_immediate_range(other):
|
||||
si_inst(res, self, other)
|
||||
else:
|
||||
@@ -1086,11 +1125,11 @@ class _secret(_register):
|
||||
return self.secret_op(other, subs, submr, subsfi, True)
|
||||
|
||||
@vectorize
|
||||
def __div__(self, other):
|
||||
def __truediv__(self, other):
|
||||
return self * (self.clear_type(1) / other)
|
||||
|
||||
@vectorize
|
||||
def __rdiv__(self, other):
|
||||
def __rtruediv__(self, other):
|
||||
a,b = self.get_random_inverse()
|
||||
return other * a / (a * self).reveal()
|
||||
|
||||
@@ -1253,7 +1292,7 @@ class sint(_secret, _int):
|
||||
|
||||
@vectorize
|
||||
def __mod__(self, modulus):
|
||||
if isinstance(modulus, (int, long)):
|
||||
if isinstance(modulus, int):
|
||||
l = math.log(modulus, 2)
|
||||
if 2**int(round(l)) == modulus:
|
||||
return self.mod2m(int(l))
|
||||
@@ -1405,7 +1444,7 @@ class sgf2n(_secret, _gf2n):
|
||||
return self ^ cgf2n(2**program.galois_length - 1)
|
||||
|
||||
def __xor__(self, other):
|
||||
if other is 0 or other is 0L:
|
||||
if other is 0:
|
||||
return self
|
||||
else:
|
||||
return super(sgf2n, self).add(other)
|
||||
@@ -1414,7 +1453,7 @@ class sgf2n(_secret, _gf2n):
|
||||
|
||||
@vectorize
|
||||
def __and__(self, other):
|
||||
if isinstance(other, (int, long)):
|
||||
if isinstance(other, int):
|
||||
other_bits = [(other >> i) & 1 \
|
||||
for i in range(program.galois_length)]
|
||||
else:
|
||||
@@ -1515,7 +1554,7 @@ class _bitint(object):
|
||||
else:
|
||||
pre_op = floatingpoint.PreOpL
|
||||
if d:
|
||||
carries = zip(*pre_op(carry, [(0, carry_in)] + d))[1]
|
||||
carries = list(zip(*pre_op(carry, [(0, carry_in)] + d)))[1]
|
||||
else:
|
||||
carries = []
|
||||
res = lower + cls.sum_from_carries(a, b, carries)
|
||||
@@ -1539,7 +1578,7 @@ class _bitint(object):
|
||||
for k in range(m, -1, -1):
|
||||
if sum(range(m, k - 1, -1)) + 1 >= n:
|
||||
break
|
||||
blocks = range(m, k, -1)
|
||||
blocks = list(range(m, k, -1))
|
||||
blocks.append(n - sum(blocks))
|
||||
blocks.reverse()
|
||||
#print 'blocks:', blocks
|
||||
@@ -1597,9 +1636,9 @@ class _bitint(object):
|
||||
|
||||
@staticmethod
|
||||
def get_highest_different_bits(a, b, index):
|
||||
diff = [ai + bi for (ai,bi) in reversed(zip(a,b))]
|
||||
diff = [ai + bi for (ai,bi) in reversed(list(zip(a,b)))]
|
||||
preor = floatingpoint.PreOR(diff, raw=True)
|
||||
highest_diff = [x - y for (x,y) in reversed(zip(preor, [0] + preor))]
|
||||
highest_diff = [x - y for (x,y) in reversed(list(zip(preor, [0] + preor)))]
|
||||
raw = sum(map(operator.mul, highest_diff, (a,b)[index]))
|
||||
return raw.bit_decompose()[0]
|
||||
|
||||
@@ -1622,7 +1661,7 @@ class _bitint(object):
|
||||
if type(other) == self.bin_type:
|
||||
raise CompilerError('Unclear multiplication')
|
||||
self_bits = self.bit_decompose()
|
||||
if isinstance(other, (int, long)):
|
||||
if isinstance(other, int):
|
||||
other_bits = util.bit_decompose(other, self.n_bits)
|
||||
bit_matrix = [[x * y for y in self_bits] for x in other_bits]
|
||||
else:
|
||||
@@ -1644,8 +1683,8 @@ class _bitint(object):
|
||||
|
||||
@classmethod
|
||||
def wallace_tree_from_matrix(cls, bit_matrix, get_carry=True):
|
||||
columns = [filter(None, (bit_matrix[j][i-j] \
|
||||
for j in range(min(len(bit_matrix), i + 1)))) \
|
||||
columns = [[_f for _f in (bit_matrix[j][i-j] \
|
||||
for j in range(min(len(bit_matrix), i + 1))) if _f] \
|
||||
for i in range(len(bit_matrix[0]))]
|
||||
return cls.wallace_tree_from_columns(columns, get_carry)
|
||||
|
||||
@@ -1671,7 +1710,7 @@ class _bitint(object):
|
||||
columns = new_columns[:-1]
|
||||
for col in columns:
|
||||
col.extend([0] * (2 - len(col)))
|
||||
return self.bit_adder(*zip(*columns))
|
||||
return self.bit_adder(*list(zip(*columns)))
|
||||
|
||||
@classmethod
|
||||
def wallace_tree(cls, rows):
|
||||
@@ -1685,17 +1724,17 @@ class _bitint(object):
|
||||
d = [(1 + ai + bi, (1 - ai) * bi) for (ai,bi) in zip(a,b)]
|
||||
borrow = lambda y,x,*args: \
|
||||
(x[0] * y[0], 1 - (1 - x[1]) * (1 - x[0] * y[1]))
|
||||
borrows = (0,) + zip(*floatingpoint.PreOpL(borrow, d))[1]
|
||||
borrows = (0,) + list(zip(*floatingpoint.PreOpL(borrow, d)))[1]
|
||||
return self.compose(ai + bi + borrow \
|
||||
for (ai,bi,borrow) in zip(a,b,borrows))
|
||||
|
||||
def __rsub__(self, other):
|
||||
raise NotImplementedError()
|
||||
|
||||
def __div__(self, other):
|
||||
def __truediv__(self, other):
|
||||
raise NotImplementedError()
|
||||
|
||||
def __rdiv__(self, other):
|
||||
def __truerdiv__(self, other):
|
||||
raise NotImplementedError()
|
||||
|
||||
def __lshift__(self, other):
|
||||
@@ -1953,7 +1992,7 @@ class cfix(_number, _structure):
|
||||
""" Clear fixed point type. """
|
||||
__slots__ = ['value', 'f', 'k', 'size']
|
||||
reg_type = 'c'
|
||||
scalars = (int, long, float, regint)
|
||||
scalars = (int, float, regint)
|
||||
@classmethod
|
||||
def set_precision(cls, f, k = None):
|
||||
# k is the whole bitlength of fixed point
|
||||
@@ -1978,7 +2017,7 @@ class cfix(_number, _structure):
|
||||
if n == 1:
|
||||
return cfix(cint_inputs)
|
||||
else:
|
||||
return map(cfix, cint_inputs)
|
||||
return list(map(cfix, cint_inputs))
|
||||
|
||||
@vectorize
|
||||
def write_to_socket(self, client_id, message_type=ClientMessageType.NoType):
|
||||
@@ -1990,7 +2029,7 @@ class cfix(_number, _structure):
|
||||
""" Send a list of cfix values to socket. Values are sent as bit shifted cints. """
|
||||
def cfix_to_cint(fix_val):
|
||||
return cint(fix_val.v)
|
||||
cint_values = map(cfix_to_cint, values)
|
||||
cint_values = list(map(cfix_to_cint, values))
|
||||
writesocketc(client_id, message_type, *cint_values)
|
||||
|
||||
@staticmethod
|
||||
@@ -2152,7 +2191,7 @@ class cfix(_number, _structure):
|
||||
raise NotImplementedError
|
||||
|
||||
@vectorize
|
||||
def __div__(self, other):
|
||||
def __truediv__(self, other):
|
||||
other = parse_type(other)
|
||||
if isinstance(other, cfix):
|
||||
return cfix(library.cint_cint_division(self.v, other.v, self.k, self.f))
|
||||
@@ -2190,7 +2229,7 @@ class _single(_number, _structure):
|
||||
""" Securely obtain shares of n values input by a client.
|
||||
Assumes client has already run bit shift to convert fixed point to integer."""
|
||||
sint_inputs = cls.int_type.receive_from_client(n, client_id, ClientMessageType.TripleShares)
|
||||
return map(cls, sint_inputs)
|
||||
return list(map(cls, sint_inputs))
|
||||
|
||||
@vectorized_classmethod
|
||||
def load_mem(cls, address, mem_type=None):
|
||||
@@ -2333,7 +2372,7 @@ class _fix(_single):
|
||||
|
||||
@classmethod
|
||||
def coerce(cls, other):
|
||||
if isinstance(other, (_fix, cfix)):
|
||||
if isinstance(other, (_fix, cls.clear_type)):
|
||||
return other
|
||||
else:
|
||||
return cls.conv(other)
|
||||
@@ -2402,7 +2441,7 @@ class _fix(_single):
|
||||
|
||||
@vectorize
|
||||
def mul(self, other):
|
||||
if isinstance(other, (sint, cint, regint, int, long)):
|
||||
if isinstance(other, (sint, cint, regint, int)):
|
||||
return self._new(self.v * other, k=self.k, f=self.f)
|
||||
elif isinstance(other, float):
|
||||
if int(other) == other:
|
||||
@@ -2413,13 +2452,11 @@ class _fix(_single):
|
||||
f = self.f
|
||||
while v % 2 == 0:
|
||||
f -= 1
|
||||
v /= 2
|
||||
v //= 2
|
||||
k = len(bin(abs(v))) - 1
|
||||
other = cfix(cint(v))
|
||||
other.f = f
|
||||
other.k = k
|
||||
other = self.multipliable(v, k, f)
|
||||
other = self.coerce(other)
|
||||
if isinstance(other, (_fix, cfix)):
|
||||
if isinstance(other, (_fix, self.clear_type)):
|
||||
val = self.v.TruncMul(other.v, self.k + other.k, other.f,
|
||||
self.kappa,
|
||||
self.round_nearest)
|
||||
@@ -2438,7 +2475,7 @@ class _fix(_single):
|
||||
return type(self)(-self.v)
|
||||
|
||||
@vectorize
|
||||
def __div__(self, other):
|
||||
def __truediv__(self, other):
|
||||
other = self.coerce(other)
|
||||
if isinstance(other, _fix):
|
||||
return type(self)(library.FPDiv(self.v, other.v, self.k, self.f,
|
||||
@@ -2450,7 +2487,7 @@ class _fix(_single):
|
||||
raise TypeError('Incompatible fixed point types in division')
|
||||
|
||||
@vectorize
|
||||
def __rdiv__(self, other):
|
||||
def __rtruediv__(self, other):
|
||||
return self.coerce(other) / self
|
||||
|
||||
@vectorize
|
||||
@@ -2497,6 +2534,10 @@ class sfix(_fix):
|
||||
def unreduced(self, v, other=None, res_params=None, n_summands=1):
|
||||
return unreduced_sfix(v, self.k * 2, self.f, self.kappa)
|
||||
|
||||
@staticmethod
|
||||
def multipliable(v, k, f):
|
||||
return cfix(cint.conv(v), k, f)
|
||||
|
||||
class unreduced_sfix(_single):
|
||||
int_type = sint
|
||||
|
||||
@@ -2511,7 +2552,7 @@ class unreduced_sfix(_single):
|
||||
self.kappa = kappa
|
||||
|
||||
def __add__(self, other):
|
||||
if other is 0 or other is 0L:
|
||||
if other is 0:
|
||||
return self
|
||||
assert self.k == other.k
|
||||
assert self.m == other.m
|
||||
@@ -2643,7 +2684,7 @@ class _unreduced_squant(object):
|
||||
self.res_params = res_params or params[0]
|
||||
|
||||
def __add__(self, other):
|
||||
if other is 0 or other is 0L:
|
||||
if other is 0:
|
||||
return self
|
||||
assert self.params == other.params
|
||||
assert self.res_params == other.res_params
|
||||
@@ -2807,10 +2848,10 @@ class sfloat(_number, _structure):
|
||||
v = int(round(abs(v) * 2 ** (-p)))
|
||||
if v == 2 ** vlen:
|
||||
p += 1
|
||||
v /= 2
|
||||
v //= 2
|
||||
z = 0
|
||||
if p < -2 ** (plen - 1):
|
||||
print 'Warning: %e truncated to zero' % vv
|
||||
print('Warning: %e truncated to zero' % vv)
|
||||
v, p, z = 0, 0, 1
|
||||
if p >= 2 ** (plen - 1):
|
||||
raise CompilerError('Cannot convert %s to float ' \
|
||||
@@ -2950,8 +2991,8 @@ class sfloat(_number, _structure):
|
||||
v = t
|
||||
u = floatingpoint.BitDec(v, self.vlen + 2 + sfloat.round_nearest,
|
||||
self.vlen + 2 + sfloat.round_nearest, self.kappa,
|
||||
range(1 + sfloat.round_nearest,
|
||||
self.vlen + 2 + sfloat.round_nearest))
|
||||
list(range(1 + sfloat.round_nearest,
|
||||
self.vlen + 2 + sfloat.round_nearest)))
|
||||
# using u[0] doesn't seem necessary
|
||||
h = floatingpoint.PreOR(u[:sfloat.round_nearest:-1], self.kappa)
|
||||
p0 = self.vlen + 1 - sum(h)
|
||||
@@ -3013,7 +3054,7 @@ class sfloat(_number, _structure):
|
||||
def __rsub__(self, other):
|
||||
return -self + other
|
||||
|
||||
def __div__(self, other):
|
||||
def __truediv__(self, other):
|
||||
other = self.conv(other)
|
||||
v = floatingpoint.SDiv(self.v, other.v + other.z * (2**self.vlen - 1),
|
||||
self.vlen, self.kappa, self.round_nearest)
|
||||
@@ -3029,21 +3070,16 @@ class sfloat(_number, _structure):
|
||||
sfloat.set_error(other.z)
|
||||
return sfloat(v, p, z, s)
|
||||
|
||||
def __rdiv__(self, other):
|
||||
def __rtruediv__(self, other):
|
||||
return self.conv(other) / self
|
||||
|
||||
@vectorize
|
||||
def __neg__(self):
|
||||
return sfloat(self.v, self.p, self.z, (1 - self.s) * (1 - self.z))
|
||||
|
||||
def __abs__(self):
|
||||
if self.s:
|
||||
return -self
|
||||
else:
|
||||
return self
|
||||
|
||||
@vectorize
|
||||
def __lt__(self, other):
|
||||
other = self.conv(other)
|
||||
if isinstance(other, sfloat):
|
||||
z1 = self.z
|
||||
z2 = other.z
|
||||
@@ -3066,8 +3102,15 @@ class sfloat(_number, _structure):
|
||||
def __ge__(self, other):
|
||||
return 1 - (self < other)
|
||||
|
||||
def __gt__(self, other):
|
||||
return self.conv(other) < self
|
||||
|
||||
def __le__(self, other):
|
||||
return self.conv(other) >= self
|
||||
|
||||
@vectorize
|
||||
def __eq__(self, other):
|
||||
other = self.conv(other)
|
||||
# the sign can be both ways for zeroes
|
||||
both_zero = self.z * other.z
|
||||
return floatingpoint.EQZ(self.v - other.v, self.vlen, self.kappa) * \
|
||||
@@ -3151,24 +3194,25 @@ class Array(object):
|
||||
program.free(self.address, self.value_type.reg_type)
|
||||
|
||||
def get_address(self, index):
|
||||
key = str(index)
|
||||
if isinstance(index, int) and self.length is not None:
|
||||
index += self.length * (index < 0)
|
||||
if index >= self.length or index < 0:
|
||||
raise IndexError('index %s, length %s' % \
|
||||
(str(index), str(self.length)))
|
||||
if (program.curr_block, index) not in self.address_cache:
|
||||
if (program.curr_block, key) not in self.address_cache:
|
||||
n = self.value_type.n_elements()
|
||||
length = self.length
|
||||
if n == 1:
|
||||
# length can be None for single-element arrays
|
||||
length = 0
|
||||
self.address_cache[program.curr_block, index] = \
|
||||
self.address_cache[program.curr_block, key] = \
|
||||
util.untuplify([self.address + index + i * length \
|
||||
for i in range(n)])
|
||||
if self.debug:
|
||||
library.print_ln_if(index >= self.length, 'OF:' + self.debug)
|
||||
library.print_ln_if(self.address_cache[program.curr_block, index] >= program.allocated_mem[self.value_type.reg_type], 'AOF:' + self.debug)
|
||||
return self.address_cache[program.curr_block, index]
|
||||
library.print_ln_if(self.address_cache[program.curr_block, key] >= program.allocated_mem[self.value_type.reg_type], 'AOF:' + self.debug)
|
||||
return self.address_cache[program.curr_block, key]
|
||||
|
||||
def get_slice(self, index):
|
||||
if index.stop is None and self.length is None:
|
||||
@@ -3178,7 +3222,7 @@ class Array(object):
|
||||
def __getitem__(self, index):
|
||||
if isinstance(index, slice):
|
||||
start, stop, step = self.get_slice(index)
|
||||
res_length = (stop - start - 1) / step + 1
|
||||
res_length = (stop - start - 1) // step + 1
|
||||
res = Array(res_length, self.value_type)
|
||||
@library.for_range(res_length)
|
||||
def f(i):
|
||||
@@ -3303,7 +3347,7 @@ class SubMultiArray(object):
|
||||
def __getitem__(self, index):
|
||||
if util.is_constant(index) and index >= self.sizes[0]:
|
||||
raise StopIteration
|
||||
key = program.curr_block, index
|
||||
key = program.curr_block, str(index)
|
||||
if key not in self.sub_cache:
|
||||
if self.debug:
|
||||
library.print_ln_if(index >= self.sizes[0], \
|
||||
@@ -3531,7 +3575,7 @@ class _mem(_number):
|
||||
__add__ = lambda self,other: self.read() + other
|
||||
__sub__ = lambda self,other: self.read() - other
|
||||
__mul__ = lambda self,other: self.read() * other
|
||||
__div__ = lambda self,other: self.read() / other
|
||||
__truediv__ = lambda self,other: self.read() / other
|
||||
__mod__ = lambda self,other: self.read() % other
|
||||
__pow__ = lambda self,other: self.read() ** other
|
||||
__neg__ = lambda self,other: -self.read()
|
||||
@@ -3550,7 +3594,7 @@ class _mem(_number):
|
||||
__radd__ = lambda self,other: other + self.read()
|
||||
__rsub__ = lambda self,other: other - self.read()
|
||||
__rmul__ = lambda self,other: other * self.read()
|
||||
__rdiv__ = lambda self,other: other / self.read()
|
||||
__rtruediv__ = lambda self,other: other / self.read()
|
||||
__rmod__ = lambda self,other: other % self.read()
|
||||
__rand__ = lambda self,other: other & self.read()
|
||||
__rxor__ = lambda self,other: other ^ self.read()
|
||||
@@ -3627,7 +3671,7 @@ class MemValue(_mem):
|
||||
self.check()
|
||||
if isinstance(value, MemValue):
|
||||
self.register = value.read()
|
||||
elif isinstance(value, (int,long)):
|
||||
elif isinstance(value, int):
|
||||
self.register = self.value_type(value)
|
||||
else:
|
||||
self.register = value
|
||||
@@ -3717,7 +3761,7 @@ def getNamedTupleType(*names):
|
||||
class NamedTuple(object):
|
||||
class NamedTupleArray(object):
|
||||
def __init__(self, size, t):
|
||||
import types
|
||||
from . import types
|
||||
self.arrays = [types.Array(size, t) for i in range(len(names))]
|
||||
def __getitem__(self, index):
|
||||
return NamedTuple(array[index] for array in self.arrays)
|
||||
@@ -3749,4 +3793,4 @@ def getNamedTupleType(*names):
|
||||
return self.__type__(x.reveal() for x in self)
|
||||
return NamedTuple
|
||||
|
||||
import library
|
||||
from . import library
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import math
|
||||
import operator
|
||||
from functools import reduce
|
||||
|
||||
def format_trace(trace, prefix=' '):
|
||||
if trace is None:
|
||||
@@ -46,7 +47,7 @@ def right_shift(a, b, bits):
|
||||
return a.right_shift(b, bits)
|
||||
|
||||
def bit_decompose(a, bits):
|
||||
if isinstance(a, (int,long)):
|
||||
if isinstance(a, int):
|
||||
return [int((a >> i) & 1) for i in range(bits)]
|
||||
else:
|
||||
return a.bit_decompose(bits)
|
||||
@@ -82,7 +83,7 @@ def if_else(cond, a, b):
|
||||
else:
|
||||
return cond.if_else(a, b)
|
||||
except:
|
||||
print cond, a, b
|
||||
print(cond, a, b)
|
||||
raise
|
||||
|
||||
def cond_swap(cond, a, b):
|
||||
@@ -112,8 +113,8 @@ def tree_reduce(function, sequence):
|
||||
if n == 1:
|
||||
return sequence[0]
|
||||
else:
|
||||
reduced = [function(sequence[2*i], sequence[2*i+1]) for i in range(n/2)]
|
||||
return tree_reduce(function, reduced + sequence[n/2*2:])
|
||||
reduced = [function(sequence[2*i], sequence[2*i+1]) for i in range(n//2)]
|
||||
return tree_reduce(function, reduced + sequence[n//2*2:])
|
||||
|
||||
def or_op(a, b):
|
||||
return a + b - a * b
|
||||
@@ -144,7 +145,7 @@ def reveal(x):
|
||||
return x
|
||||
|
||||
def is_constant(x):
|
||||
return isinstance(x, (int, long, bool))
|
||||
return isinstance(x, (int, bool))
|
||||
|
||||
def is_constant_float(x):
|
||||
return isinstance(x, float) or is_constant(x)
|
||||
@@ -180,3 +181,44 @@ def expand(x, size):
|
||||
return x.expand_to_vector(size)
|
||||
except AttributeError:
|
||||
return x
|
||||
|
||||
class set_by_id(object):
|
||||
def __init__(self, init=[]):
|
||||
self.content = {}
|
||||
for x in init:
|
||||
self.add(x)
|
||||
|
||||
def __contains__(self, value):
|
||||
return id(value) in self.content
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.content.values())
|
||||
|
||||
def add(self, value):
|
||||
self.content[id(value)] = value
|
||||
|
||||
class dict_by_id(object):
|
||||
def __init__(self):
|
||||
self.content = {}
|
||||
|
||||
def __contains__(self, key):
|
||||
return id(key) in self.content
|
||||
|
||||
def __getitem__(self, key):
|
||||
return self.content[id(key)][1]
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
self.content[id(key)] = (key, value)
|
||||
|
||||
def keys(self):
|
||||
return (x[0] for x in self.content.values())
|
||||
|
||||
class defaultdict_by_id(dict_by_id):
|
||||
def __init__(self, default):
|
||||
dict_by_id.__init__(self)
|
||||
self.default = default
|
||||
|
||||
def __getitem__(self, key):
|
||||
if key not in self:
|
||||
self[key] = self.default()
|
||||
return dict_by_id.__getitem__(self, key)
|
||||
|
||||
66
ECDSA/EcdsaOptions.h
Normal file
66
ECDSA/EcdsaOptions.h
Normal file
@@ -0,0 +1,66 @@
|
||||
/*
|
||||
* EcdsaOptions.h
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef ECDSA_ECDSAOPTIONS_H_
|
||||
#define ECDSA_ECDSAOPTIONS_H_
|
||||
|
||||
#include "Tools/ezOptionParser.h"
|
||||
|
||||
class EcdsaOptions
|
||||
{
|
||||
public:
|
||||
bool prep_mul;
|
||||
bool fewer_rounds;
|
||||
bool check_open;
|
||||
bool check_beaver_open;
|
||||
|
||||
EcdsaOptions(ez::ezOptionParser& opt, int argc, const char** argv)
|
||||
{
|
||||
opt.add(
|
||||
"", // Default.
|
||||
0, // Required?
|
||||
0, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
"Delay multiplication until signing", // Help description.
|
||||
"-D", // Flag token.
|
||||
"--delay-multiplication" // Flag token.
|
||||
);
|
||||
opt.add(
|
||||
"", // Default.
|
||||
0, // Required?
|
||||
0, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
"Fewer rounds, more EC", // Help description.
|
||||
"-P", // Flag token.
|
||||
"--parallel-open" // Flag token.
|
||||
);
|
||||
opt.add(
|
||||
"", // Default.
|
||||
0, // Required?
|
||||
0, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
"Skip checking final openings (but not necessarily openings for Beaver; only relevant with active protocols)", // Help description.
|
||||
"-C", // Flag token.
|
||||
"--no-open-check" // Flag token.
|
||||
);
|
||||
opt.add(
|
||||
"", // Default.
|
||||
0, // Required?
|
||||
0, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
"Skip checking Beaver openings (only relevant with active protocols)", // Help description.
|
||||
"-B", // Flag token.
|
||||
"--no-beaver-open-check" // Flag token.
|
||||
);
|
||||
opt.parse(argc, argv);
|
||||
prep_mul = not opt.isSet("-D");
|
||||
fewer_rounds = opt.isSet("-P");
|
||||
check_open = not opt.isSet("-C");
|
||||
check_beaver_open = not opt.isSet("-B");
|
||||
opt.resetArgs();
|
||||
}
|
||||
};
|
||||
|
||||
#endif /* ECDSA_ECDSAOPTIONS_H_ */
|
||||
@@ -18,7 +18,7 @@ int main()
|
||||
string prefix = PREP_DIR "ECDSA/";
|
||||
mkdir_p(prefix.c_str());
|
||||
ofstream outf;
|
||||
write_online_setup(outf, prefix, P256Element::Scalar::pr(), 0, false);
|
||||
write_online_setup_without_init(outf, prefix, P256Element::Scalar::pr(), 0);
|
||||
generate_mac_keys<Share<P256Element::Scalar>>(key, key2, 2, prefix);
|
||||
make_mult_triples<Share<P256Element::Scalar>>(key, 2, 1000, false, prefix);
|
||||
make_inverse<Share<P256Element::Scalar>>(key, 2, 1000, false, prefix);
|
||||
|
||||
@@ -77,6 +77,14 @@ P256Element& P256Element::operator +=(const P256Element& other)
|
||||
return *this;
|
||||
}
|
||||
|
||||
P256Element& P256Element::operator /=(const Scalar& other)
|
||||
{
|
||||
auto tmp = other;
|
||||
tmp.invert();
|
||||
*this = *this * tmp;
|
||||
return *this;
|
||||
}
|
||||
|
||||
bool P256Element::operator ==(const P256Element& other) const
|
||||
{
|
||||
return point == other.point;
|
||||
|
||||
@@ -10,14 +10,10 @@
|
||||
|
||||
#include "Math/gfp.h"
|
||||
|
||||
#if GFP_MOD_SZ != 4
|
||||
#error GFP_MOD_SZ must be 4
|
||||
#endif
|
||||
|
||||
class P256Element : public ValueInterface
|
||||
{
|
||||
public:
|
||||
typedef gfp Scalar;
|
||||
typedef gfp_<2, 4> Scalar;
|
||||
|
||||
private:
|
||||
static CryptoPP::DL_GroupParameters_EC<CryptoPP::ECP> params;
|
||||
@@ -52,6 +48,7 @@ public:
|
||||
P256Element operator*(const Scalar& other) const;
|
||||
|
||||
P256Element& operator+=(const P256Element& other);
|
||||
P256Element& operator/=(const Scalar& other);
|
||||
|
||||
bool operator==(const P256Element& other) const;
|
||||
bool operator!=(const P256Element& other) const;
|
||||
|
||||
@@ -5,8 +5,7 @@ in `preprocessing.hpp` and `sign.hpp`, respectively.
|
||||
|
||||
#### Compilation
|
||||
|
||||
- Add `MOD = -DGFP_MOD_SZ=4` to `CONFIG.mine`
|
||||
- Also consider adding either `CXX = clang++` or `OPTIM = -O2` because GCC 8 or later with `-O3` will produce a segfault when using `mascot-ecdsa-party.x`
|
||||
- Add either `CXX = clang++` or `OPTIM = -O2` because GCC 8 or later with `-O3` will produce a segfault when using `mascot-ecdsa-party.x`
|
||||
- For older hardware, also add `ARCH = -march=native`
|
||||
- Install [Crypto++](https://www.cryptopp.com) (`libcrypto++-dev` on Ubuntu). We used version 5.6.4, which is the default on Ubuntu 18.04.
|
||||
- Compile the binaries: `make -j8 ecdsa`
|
||||
|
||||
@@ -23,6 +23,7 @@
|
||||
int main(int argc, const char** argv)
|
||||
{
|
||||
ez::ezOptionParser opt;
|
||||
EcdsaOptions opts(opt, argc, argv);
|
||||
Names N(opt, argc, argv, 2);
|
||||
int n_tuples = 1000;
|
||||
if (not opt.lastArgs.empty())
|
||||
@@ -38,7 +39,7 @@ int main(int argc, const char** argv)
|
||||
typedef Share<P256Element::Scalar> pShare;
|
||||
DataPositions usage;
|
||||
Sub_Data_Files<pShare> prep(N, prefix, usage);
|
||||
typename pShare::MAC_Check MCp(keyp);
|
||||
typename pShare::Direct_MC MCp(keyp);
|
||||
ArithmeticProcessor _({}, 0);
|
||||
SubProcessor<pShare> proc(_, MCp, prep, P);
|
||||
|
||||
@@ -46,7 +47,7 @@ int main(int argc, const char** argv)
|
||||
proc.DataF.get_two(DATA_INVERSE, sk, __);
|
||||
|
||||
vector<EcTuple<Share>> tuples;
|
||||
preprocessing(tuples, n_tuples, sk, proc);
|
||||
preprocessing(tuples, n_tuples, sk, proc, opts);
|
||||
check(tuples, sk, keyp, P);
|
||||
sign_benchmark(tuples, sk, MCp, P);
|
||||
sign_benchmark(tuples, sk, MCp, P, opts);
|
||||
}
|
||||
|
||||
@@ -29,15 +29,7 @@ void run(int argc, const char** argv)
|
||||
{
|
||||
bigint::init_thread();
|
||||
ez::ezOptionParser opt;
|
||||
opt.add(
|
||||
"", // Default.
|
||||
0, // Required?
|
||||
0, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
"Delay multiplication until signing", // Help description.
|
||||
"-D", // Flag token.
|
||||
"--delay-multiplication" // Flag token.
|
||||
);
|
||||
EcdsaOptions opts(opt, argc, argv);
|
||||
Names N(opt, argc, argv, 3);
|
||||
int n_tuples = 1000;
|
||||
if (not opt.lastArgs.empty())
|
||||
@@ -51,10 +43,12 @@ void run(int argc, const char** argv)
|
||||
P.Broadcast_Receive(bundle, false);
|
||||
Timer timer;
|
||||
timer.start();
|
||||
auto stats = P.comm_stats;
|
||||
pShare sk = typename T<P256Element::Scalar>::Honest::Protocol(P).get_random();
|
||||
cout << "Secret key generation took " << timer.elapsed() * 1e3 << " ms" << endl;
|
||||
(P.comm_stats - stats).print(true);
|
||||
|
||||
OnlineOptions::singleton.batch_size = n_tuples;
|
||||
OnlineOptions::singleton.batch_size = (1 + pShare::Protocol::uses_triples) * n_tuples;
|
||||
DataPositions usage;
|
||||
auto& prep = *Preprocessing<pShare>::get_live_prep(0, usage);
|
||||
typename pShare::MAC_Check MCp;
|
||||
@@ -63,9 +57,9 @@ void run(int argc, const char** argv)
|
||||
|
||||
bool prep_mul = not opt.isSet("-D");
|
||||
vector<EcTuple<T>> tuples;
|
||||
preprocessing(tuples, n_tuples, sk, proc, prep_mul);
|
||||
preprocessing(tuples, n_tuples, sk, proc, opts);
|
||||
// check(tuples, sk, {}, P);
|
||||
sign_benchmark(tuples, sk, MCp, P, prep_mul ? 0 : &proc);
|
||||
sign_benchmark(tuples, sk, MCp, P, opts, prep_mul ? 0 : &proc);
|
||||
|
||||
delete &prep;
|
||||
}
|
||||
|
||||
@@ -25,27 +25,72 @@ template<template<class U> class T>
|
||||
void run(int argc, const char** argv)
|
||||
{
|
||||
ez::ezOptionParser opt;
|
||||
EcdsaOptions opts(opt, argc, argv);
|
||||
opt.add(
|
||||
"", // Default.
|
||||
0, // Required?
|
||||
0, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
"Delay multiplication until signing", // Help description.
|
||||
"-D", // Flag token.
|
||||
"--delay-multiplication" // Flag token.
|
||||
"Use SimpleOT instead of OT extension", // Help description.
|
||||
"-S", // Flag token.
|
||||
"--simple-ot" // Flag token.
|
||||
);
|
||||
opt.add(
|
||||
"", // Default.
|
||||
0, // Required?
|
||||
0, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
"Don't check correlation in OT extension (only relevant with MASCOT)", // Help description.
|
||||
"-U", // Flag token.
|
||||
"--unchecked-correlation" // Flag token.
|
||||
);
|
||||
opt.add(
|
||||
"", // Default.
|
||||
0, // Required?
|
||||
0, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
"Fewer rounds for authentication (only relevant with MASCOT)", // Help description.
|
||||
"-A", // Flag token.
|
||||
"--auth-fewer-rounds" // Flag token.
|
||||
);
|
||||
opt.add(
|
||||
"", // Default.
|
||||
0, // Required?
|
||||
0, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
"Use Fiat-Shamir for amplification (only relevant with MASCOT)", // Help description.
|
||||
"-H", // Flag token.
|
||||
"--fiat-shamir" // Flag token.
|
||||
);
|
||||
opt.add(
|
||||
"", // Default.
|
||||
0, // Required?
|
||||
0, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
"Skip sacrifice (only relevant with MASCOT)", // Help description.
|
||||
"-E", // Flag token.
|
||||
"--embrace-life" // Flag token.
|
||||
);
|
||||
opt.add(
|
||||
"", // Default.
|
||||
0, // Required?
|
||||
0, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
"No MACs (only relevant with MASCOT; implies skipping MAC checks)", // Help description.
|
||||
"-M", // Flag token.
|
||||
"--no-macs" // Flag token.
|
||||
);
|
||||
|
||||
Names N(opt, argc, argv, 2);
|
||||
int n_tuples = 1000;
|
||||
if (not opt.lastArgs.empty())
|
||||
n_tuples = atoi(opt.lastArgs[0]->c_str());
|
||||
PlainPlayer P(N);
|
||||
P256Element::init();
|
||||
gfp1::init_field(P256Element::Scalar::pr(), false);
|
||||
P256Element::Scalar::next::init_field(P256Element::Scalar::pr(), false);
|
||||
|
||||
BaseMachine machine;
|
||||
machine.ot_setups.resize(1);
|
||||
for (int i = 0; i < 2; i++)
|
||||
machine.ot_setups[0].push_back({P, true});
|
||||
machine.ot_setups.push_back({P, true});
|
||||
|
||||
P256Element::Scalar keyp;
|
||||
SeededPRNG G;
|
||||
@@ -65,16 +110,28 @@ void run(int argc, const char** argv)
|
||||
P.Broadcast_Receive(bundle, false);
|
||||
Timer timer;
|
||||
timer.start();
|
||||
auto stats = P.comm_stats;
|
||||
sk_prep.get_two(DATA_INVERSE, sk, __);
|
||||
cout << "Secret key generation took " << timer.elapsed() * 1e3 << " ms" << endl;
|
||||
(P.comm_stats - stats).print(true);
|
||||
|
||||
OnlineOptions::singleton.batch_size = n_tuples;
|
||||
OnlineOptions::singleton.batch_size = (1 + pShare::Protocol::uses_triples) * n_tuples;
|
||||
typename pShare::LivePrep prep(0, usage);
|
||||
prep.params.correlation_check &= not opt.isSet("-U");
|
||||
prep.params.fewer_rounds = opt.isSet("-A");
|
||||
prep.params.fiat_shamir = opt.isSet("-H");
|
||||
prep.params.check = not opt.isSet("-E");
|
||||
prep.params.generateMACs = not opt.isSet("-M");
|
||||
opts.check_beaver_open &= prep.params.generateMACs;
|
||||
opts.check_open &= prep.params.generateMACs;
|
||||
SubProcessor<pShare> proc(_, MCp, prep, P);
|
||||
typename pShare::prep_type::Direct_MC MCpp(keyp);
|
||||
prep.triple_generator->MC = &MCpp;
|
||||
|
||||
bool prep_mul = not opt.isSet("-D");
|
||||
prep.params.use_extension = not opt.isSet("-S");
|
||||
vector<EcTuple<T>> tuples;
|
||||
preprocessing(tuples, n_tuples, sk, proc, prep_mul);
|
||||
preprocessing(tuples, n_tuples, sk, proc, opts);
|
||||
//check(tuples, sk, keyp, P);
|
||||
sign_benchmark(tuples, sk, MCp, P, prep_mul ? 0 : &proc);
|
||||
sign_benchmark(tuples, sk, MCp, P, opts, prep_mul ? 0 : &proc);
|
||||
}
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#define ECDSA_PREPROCESSING_HPP_
|
||||
|
||||
#include "P256Element.h"
|
||||
#include "EcdsaOptions.h"
|
||||
#include "Processor/Data_Files.h"
|
||||
#include "Protocols/ReplicatedPrep.h"
|
||||
#include "Protocols/MaliciousShamirShare.h"
|
||||
@@ -23,40 +24,66 @@ public:
|
||||
template<template<class U> class T>
|
||||
void preprocessing(vector<EcTuple<T>>& tuples, int buffer_size,
|
||||
T<P256Element::Scalar>& sk,
|
||||
SubProcessor<T<P256Element::Scalar>>& proc, bool prep_mul = true)
|
||||
SubProcessor<T<P256Element::Scalar>>& proc,
|
||||
EcdsaOptions opts)
|
||||
{
|
||||
bool prep_mul = opts.prep_mul;
|
||||
Timer timer;
|
||||
timer.start();
|
||||
Player& P = proc.P;
|
||||
auto& prep = proc.DataF;
|
||||
size_t start = P.sent + prep.data_sent();
|
||||
auto stats = P.comm_stats + prep.comm_stats();
|
||||
auto& extra_player = P;
|
||||
|
||||
auto& protocol = proc.protocol;
|
||||
auto& MCp = proc.MC;
|
||||
typedef T<typename P256Element::Scalar> pShare;
|
||||
typedef T<P256Element> cShare;
|
||||
vector<pShare> inv_ks;
|
||||
vector<cShare> secret_Rs;
|
||||
prep.buffer_triples();
|
||||
vector<pShare> bs, cs;
|
||||
for (int i = 0; i < buffer_size; i++)
|
||||
{
|
||||
pShare a, a_inv;
|
||||
prep.get_two(DATA_INVERSE, a, a_inv);
|
||||
inv_ks.push_back(a_inv);
|
||||
secret_Rs.push_back(a);
|
||||
pShare a, b, c;
|
||||
prep.get_three(DATA_TRIPLE, a, b, c);
|
||||
inv_ks.push_back(a);
|
||||
bs.push_back(b);
|
||||
cs.push_back(c);
|
||||
}
|
||||
vector<P256Element::Scalar> cs_opened;
|
||||
MCp.POpen_Begin(cs_opened, cs, extra_player);
|
||||
if (opts.fewer_rounds)
|
||||
secret_Rs.insert(secret_Rs.begin(), bs.begin(), bs.end());
|
||||
else
|
||||
{
|
||||
MCp.POpen_End(cs_opened, cs, extra_player);
|
||||
for (int i = 0; i < buffer_size; i++)
|
||||
secret_Rs.push_back(bs[i] / cs_opened[i]);
|
||||
}
|
||||
vector<P256Element> opened_Rs;
|
||||
typename cShare::Direct_MC MCc(MCp.get_alphai());
|
||||
MCc.POpen_Begin(opened_Rs, secret_Rs, extra_player);
|
||||
if (prep_mul)
|
||||
{
|
||||
protocol.init_mul(&proc);
|
||||
for (int i = 0; i < buffer_size; i++)
|
||||
protocol.prepare_mul(inv_ks[i], sk);
|
||||
protocol.exchange();
|
||||
protocol.start_exchange();
|
||||
}
|
||||
else
|
||||
prep.buffer_triples();
|
||||
vector<P256Element> opened_Rs;
|
||||
typename cShare::MAC_Check MCc(MCp.get_alphai());
|
||||
MCc.POpen(opened_Rs, secret_Rs, P);
|
||||
MCc.Check(P);
|
||||
MCp.Check(P);
|
||||
if (opts.fewer_rounds)
|
||||
MCp.POpen_End(cs_opened, cs, extra_player);
|
||||
MCc.POpen_End(opened_Rs, secret_Rs, extra_player);
|
||||
if (opts.fewer_rounds)
|
||||
for (int i = 0; i < buffer_size; i++)
|
||||
opened_Rs[i] /= cs_opened[i];
|
||||
if (prep_mul)
|
||||
protocol.stop_exchange();
|
||||
if (opts.check_open)
|
||||
MCc.Check(extra_player);
|
||||
if (opts.check_open or opts.check_beaver_open)
|
||||
MCp.Check(extra_player);
|
||||
for (int i = 0; i < buffer_size; i++)
|
||||
{
|
||||
tuples.push_back(
|
||||
@@ -68,6 +95,7 @@ void preprocessing(vector<EcTuple<T>>& tuples, int buffer_size,
|
||||
<< " seconds, throughput " << buffer_size / timer.elapsed() << ", "
|
||||
<< 1e-3 * (P.sent + prep.data_sent() - start) / buffer_size
|
||||
<< " kbytes per tuple" << endl;
|
||||
(P.comm_stats + prep.comm_stats() - stats).print(true);
|
||||
}
|
||||
|
||||
template<template<class U> class T>
|
||||
|
||||
@@ -8,10 +8,10 @@
|
||||
#include "hm-ecdsa-party.hpp"
|
||||
|
||||
template<>
|
||||
Preprocessing<Rep3Share<gfp>>* Preprocessing<Rep3Share<gfp>>::get_live_prep(
|
||||
SubProcessor<Rep3Share<gfp>>* proc, DataPositions& usage)
|
||||
Preprocessing<Rep3Share<gfp2>>* Preprocessing<Rep3Share<gfp2>>::get_live_prep(
|
||||
SubProcessor<Rep3Share<gfp2>>* proc, DataPositions& usage)
|
||||
{
|
||||
return new ReplicatedPrep<Rep3Share<gfp>>(proc, usage);
|
||||
return new ReplicatedPrep<Rep3Share<gfp2>>(proc, usage);
|
||||
}
|
||||
|
||||
int main(int argc, const char** argv)
|
||||
|
||||
@@ -71,12 +71,12 @@ EcSignature sign(const unsigned char* message, size_t length,
|
||||
prod = protocol.finalize_mul();
|
||||
}
|
||||
auto rx = tuple.R.x();
|
||||
signature.s = MC.POpen(
|
||||
signature.s = MC.open(
|
||||
tuple.a * hash_to_scalar(message, length) + prod * rx, P);
|
||||
cout << "Minimal signing took " << timer.elapsed() * 1e3 << " ms and sending "
|
||||
<< (P.sent - start) << " bytes" << endl;
|
||||
auto diff = (P.comm_stats - stats);
|
||||
diff.print();
|
||||
diff.print(true);
|
||||
return signature;
|
||||
}
|
||||
|
||||
@@ -112,26 +112,30 @@ void check(EcSignature signature, const unsigned char* message, size_t length,
|
||||
template<template<class U> class T>
|
||||
void sign_benchmark(vector<EcTuple<T>>& tuples, T<P256Element::Scalar> sk,
|
||||
typename T<P256Element::Scalar>::MAC_Check& MCp, Player& P,
|
||||
EcdsaOptions& opts,
|
||||
SubProcessor<T<P256Element::Scalar>>* proc = 0)
|
||||
{
|
||||
unsigned char message[1024];
|
||||
GlobalPRNG(P).get_octets(message, 1024);
|
||||
typename T<P256Element>::MAC_Check MCc(MCp.get_alphai());
|
||||
typename T<P256Element>::Direct_MC MCc(MCp.get_alphai());
|
||||
|
||||
// synchronize
|
||||
Bundle<octetStream> bundle(P);
|
||||
P.Broadcast_Receive(bundle, true);
|
||||
Timer timer;
|
||||
timer.start();
|
||||
P256Element pk = MCc.POpen(sk, P);
|
||||
auto stats = P.comm_stats;
|
||||
P256Element pk = MCc.open(sk, P);
|
||||
MCc.Check(P);
|
||||
cout << "Public key generation took " << timer.elapsed() * 1e3 << " ms" << endl;
|
||||
P.comm_stats.print();
|
||||
(P.comm_stats - stats).print(true);
|
||||
|
||||
for (size_t i = 0; i < min(10lu, tuples.size()); i++)
|
||||
{
|
||||
check(sign(message, 1 << i, tuples[i], MCp, P, pk, sk, proc), message,
|
||||
1 << i, pk);
|
||||
if (not opts.check_open)
|
||||
continue;
|
||||
Timer timer;
|
||||
timer.start();
|
||||
auto& check_player = MCp.get_check_player(P);
|
||||
|
||||
@@ -365,6 +365,7 @@ int main(int argc, char** argv)
|
||||
// init static gfp
|
||||
string prep_data_prefix = get_prep_dir(nparties, 128, gf2n::default_degree());
|
||||
initialise_fields(prep_data_prefix);
|
||||
bigint::init_thread();
|
||||
|
||||
// Generate session keys to decrypt data sent from each spdz engine (party)
|
||||
vector<octet*> session_keys(nparties);
|
||||
|
||||
@@ -39,6 +39,8 @@ class Ciphertext
|
||||
|
||||
// Rely on default copy assignment/constructor
|
||||
|
||||
word get_pk_id() { return pk_id; }
|
||||
|
||||
void set(const Rq_Element& a0, const Rq_Element& a1, word pk_id)
|
||||
{ cc0=a0; cc1=a1; this->pk_id = pk_id; }
|
||||
void set(const Rq_Element& a0, const Rq_Element& a1, const FHE_PK& pk);
|
||||
|
||||
@@ -9,13 +9,20 @@
|
||||
|
||||
template <class FD>
|
||||
Multiplier<FD>::Multiplier(int offset, PairwiseGenerator<FD>& generator) :
|
||||
generator(generator), machine(generator.machine),
|
||||
P(generator.P, offset),
|
||||
num_players(generator.P.num_players()),
|
||||
my_num(generator.P.my_num()),
|
||||
Multiplier(offset, generator.machine, generator.P, generator.timers)
|
||||
{
|
||||
}
|
||||
|
||||
template <class FD>
|
||||
Multiplier<FD>::Multiplier(int offset, PairwiseMachine& machine, Player& P,
|
||||
map<string, Timer>& timers) :
|
||||
machine(machine),
|
||||
P(P, offset),
|
||||
num_players(P.num_players()),
|
||||
my_num(P.my_num()),
|
||||
other_pk(machine.other_pks[(my_num + num_players - offset) % num_players]),
|
||||
other_enc_alpha(machine.enc_alphas[(my_num + num_players - offset) % num_players]),
|
||||
timers(generator.timers),
|
||||
timers(timers),
|
||||
C(machine.pk), mask(machine.pk),
|
||||
product_share(machine.setup<FD>().FieldD), rc(machine.pk),
|
||||
volatile_capacity(0)
|
||||
|
||||
@@ -21,7 +21,6 @@ class PairwiseMachine;
|
||||
template <class FD>
|
||||
class Multiplier
|
||||
{
|
||||
PairwiseGenerator<FD>& generator;
|
||||
PairwiseMachine& machine;
|
||||
OffsetPlayer P;
|
||||
int num_players, my_num;
|
||||
@@ -39,6 +38,9 @@ class Multiplier
|
||||
|
||||
public:
|
||||
Multiplier(int offset, PairwiseGenerator<FD>& generator);
|
||||
Multiplier(int offset, PairwiseMachine& machine, Player& P,
|
||||
map<string, Timer>& timers);
|
||||
|
||||
void multiply_and_add(Plaintext_<FD>& res, const Ciphertext& C,
|
||||
const Plaintext_<FD>& b);
|
||||
void multiply_and_add(Plaintext_<FD>& res, const Ciphertext& C,
|
||||
|
||||
@@ -13,6 +13,8 @@ using namespace std;
|
||||
#include "Access.h"
|
||||
#include "Processor/FixInput.h"
|
||||
|
||||
#include "Processor/ProcessorBase.hpp"
|
||||
|
||||
namespace GC
|
||||
{
|
||||
|
||||
|
||||
@@ -25,9 +25,9 @@ void SemiPrep::set_protocol(Beaver<SemiSecret>& protocol)
|
||||
(void) protocol;
|
||||
params.set_passive();
|
||||
triple_generator = new SemiSecret::TripleGenerator(
|
||||
thread.processor.machine.ot_setups.at(thread.thread_num).at(0),
|
||||
thread.processor.machine.ot_setups.at(thread.thread_num).get_fresh(),
|
||||
thread.master.N, thread.thread_num, thread.master.opts.batch_size,
|
||||
1, params, thread.P);
|
||||
1, params, {}, thread.P);
|
||||
triple_generator->multi_threaded = false;
|
||||
}
|
||||
|
||||
|
||||
@@ -95,7 +95,7 @@ ShareParty<T>::ShareParty(int argc, const char** argv, int default_batch_size) :
|
||||
else
|
||||
P = new PlainPlayer(this->N, 0xFFFF);
|
||||
for (int i = 0; i < this->machine.nthreads; i++)
|
||||
this->machine.ot_setups.push_back({{{*P, true}}});
|
||||
this->machine.ot_setups.push_back({*P, true});
|
||||
delete P;
|
||||
}
|
||||
|
||||
|
||||
@@ -152,8 +152,7 @@ void ReplicatedSecret<U>::reveal(size_t n_bits, Clear& x)
|
||||
auto& share = *this;
|
||||
vector<BitVec> opened;
|
||||
auto& party = ShareThread<U>::s();
|
||||
party.MC->POpen_Begin(opened, {share}, *party.P);
|
||||
party.MC->POpen_End(opened, {share}, *party.P);
|
||||
party.MC->POpen(opened, {share}, *party.P);
|
||||
x = IntBase(opened[0]);
|
||||
}
|
||||
|
||||
|
||||
@@ -59,7 +59,7 @@ void ThreadMaster<T>::run()
|
||||
|
||||
if (T::needs_ot)
|
||||
for (int i = 0; i < machine.nthreads; i++)
|
||||
machine.ot_setups.push_back({{*P, true}, {*P, true}});
|
||||
machine.ot_setups.push_back({*P, true});
|
||||
|
||||
for (int i = 0; i < machine.nthreads; i++)
|
||||
threads.push_back(new_thread(i));
|
||||
|
||||
@@ -31,18 +31,17 @@ void TinyPrep<T>::set_protocol(Beaver<T>& protocol)
|
||||
params.generateMACs = true;
|
||||
params.amplify = false;
|
||||
params.check = false;
|
||||
params.set_mac_key(thread.MC->get_alphai());
|
||||
triple_generator = new typename T::TripleGenerator(
|
||||
thread.processor.machine.ot_setups.at(thread.thread_num).at(0),
|
||||
thread.processor.machine.ot_setups.at(thread.thread_num).get_fresh(),
|
||||
thread.master.N, thread.thread_num,
|
||||
thread.master.opts.batch_size,
|
||||
1, params, thread.P);
|
||||
1, params, thread.MC->get_alphai(), thread.P);
|
||||
triple_generator->multi_threaded = false;
|
||||
input_generator = new typename T::part_type::TripleGenerator(
|
||||
thread.processor.machine.ot_setups.at(thread.thread_num).at(1),
|
||||
thread.processor.machine.ot_setups.at(thread.thread_num).get_fresh(),
|
||||
thread.master.N, thread.thread_num,
|
||||
thread.master.opts.batch_size,
|
||||
1, params, thread.P);
|
||||
1, params, thread.MC->get_alphai(), thread.P);
|
||||
input_generator->multi_threaded = false;
|
||||
thread.MC->get_part_MC().set_prep(*this);
|
||||
}
|
||||
|
||||
@@ -264,11 +264,8 @@ OTMachine::OTMachine(int argc, const char** argv)
|
||||
|
||||
// convert baseOT selection bits to BitVector
|
||||
// (not already BitVector due to legacy PVW code)
|
||||
baseReceiverInput = bot.receiver_inputs;
|
||||
baseReceiverInput.resize(nbase);
|
||||
for (int i = 0; i < nbase; i++)
|
||||
{
|
||||
baseReceiverInput.set_bit(i, bot.receiver_inputs[i]);
|
||||
}
|
||||
}
|
||||
|
||||
OTMachine::~OTMachine()
|
||||
|
||||
@@ -107,7 +107,7 @@ int spdz_main(int argc, const char** argv, ez::ezOptionParser& opt, bool live_pr
|
||||
1, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
"Maximum number of parties to send to at once", // Help description.
|
||||
"-b", // Flag token.
|
||||
"-B", // Flag token.
|
||||
"--max-broadcast" // Flag token.
|
||||
);
|
||||
opt.add(
|
||||
|
||||
@@ -118,6 +118,7 @@ TripleMachine::TripleMachine(int argc, const char** argv) :
|
||||
opt.get("-l")->getInt(nloops);
|
||||
generateBits = opt.get("-B")->isSet;
|
||||
check = opt.get("-c")->isSet || generateBits;
|
||||
correlation_check = opt.get("-c")->isSet;
|
||||
generateMACs = opt.get("-m")->isSet || check;
|
||||
amplify = opt.get("-a")->isSet || generateMACs;
|
||||
primeField = opt.get("-P")->isSet;
|
||||
@@ -143,21 +144,22 @@ TripleMachine::TripleMachine(int argc, const char** argv) :
|
||||
|
||||
// doesn't work with Montgomery multiplication
|
||||
gfp1::init_field(p, false);
|
||||
gfp::init_field(p, true);
|
||||
gf2n_long::init_field(128);
|
||||
|
||||
PRNG G;
|
||||
G.ReSeed();
|
||||
mac_key2l.randomize(G);
|
||||
mac_key2s.randomize(G);
|
||||
mac_key2.randomize(G);
|
||||
mac_keyp.randomize(G);
|
||||
mac_keyz.randomize(G);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
GeneratorThread* TripleMachine::new_generator(OTTripleSetup& setup, int i)
|
||||
GeneratorThread* TripleMachine::new_generator(OTTripleSetup& setup, int i,
|
||||
typename T::mac_key_type mac_key)
|
||||
{
|
||||
return new typename T::TripleGenerator(setup, N[i % nConnections], i,
|
||||
nTriplesPerThread, nloops, *this);
|
||||
nTriplesPerThread, nloops, *this, mac_key);
|
||||
}
|
||||
|
||||
void TripleMachine::run()
|
||||
@@ -180,24 +182,24 @@ void TripleMachine::run()
|
||||
for (int i = 0; i < nthreads; i++)
|
||||
{
|
||||
if (primeField)
|
||||
generators[i] = new_generator<Share<gfp>>(setup, i);
|
||||
generators[i] = new_generator<Share<gfp>>(setup, i, mac_keyp);
|
||||
else if (z2k)
|
||||
{
|
||||
if (z2k == 32 and z2s == 32)
|
||||
generators[i] = new_generator<Spdz2kShare<32, 32>>(setup, i);
|
||||
generators[i] = new_generator<Spdz2kShare<32, 32>>(setup, i, mac_keyz);
|
||||
else if (z2k == 64 and z2s == 64)
|
||||
generators[i] = new_generator<Spdz2kShare<64, 64>>(setup, i);
|
||||
generators[i] = new_generator<Spdz2kShare<64, 64>>(setup, i, mac_keyz);
|
||||
else if (z2k == 64 and z2s == 48)
|
||||
generators[i] = new_generator<Spdz2kShare<64, 48>>(setup, i);
|
||||
generators[i] = new_generator<Spdz2kShare<64, 48>>(setup, i, mac_keyz);
|
||||
else if (z2k == 66 and z2s == 64)
|
||||
generators[i] = new_generator<Spdz2kShare<66, 64>>(setup, i);
|
||||
generators[i] = new_generator<Spdz2kShare<66, 64>>(setup, i, mac_keyz);
|
||||
else if (z2k == 66 and z2s == 48)
|
||||
generators[i] = new_generator<Spdz2kShare<66, 48>>(setup, i);
|
||||
generators[i] = new_generator<Spdz2kShare<66, 48>>(setup, i, mac_keyz);
|
||||
else
|
||||
throw runtime_error("not compiled for k=" + to_string(z2k) + " and s=" + to_string(z2s));
|
||||
}
|
||||
else
|
||||
generators[i] = new_generator<Share<gf2n>>(setup, i);
|
||||
generators[i] = new_generator<Share<gf2n>>(setup, i, mac_key2);
|
||||
}
|
||||
ntriples = generators[0]->nTriples * nthreads;
|
||||
cout <<"Setup generators\n";
|
||||
@@ -251,10 +253,8 @@ void TripleMachine::run()
|
||||
void TripleMachine::output_mac_keys()
|
||||
{
|
||||
if (z2k) {
|
||||
write_mac_keys(prep_data_dir, my_num, nplayers, mac_keyz, mac_key2l);
|
||||
write_mac_keys(prep_data_dir, my_num, nplayers, mac_keyz, mac_key2);
|
||||
}
|
||||
else if (gf2n::degree() > 64)
|
||||
write_mac_keys(prep_data_dir, my_num, nplayers, mac_keyp, mac_key2l);
|
||||
else
|
||||
write_mac_keys(prep_data_dir, my_num, nplayers, mac_keyp, mac_key2s);
|
||||
write_mac_keys(prep_data_dir, my_num, nplayers, mac_keyp, mac_key2);
|
||||
}
|
||||
|
||||
29
Machines/hemi-party.cpp
Normal file
29
Machines/hemi-party.cpp
Normal file
@@ -0,0 +1,29 @@
|
||||
/*
|
||||
* hemi-party.cpp
|
||||
*
|
||||
*/
|
||||
|
||||
#include "Protocols/HemiShare.h"
|
||||
#include "Math/gfp.h"
|
||||
#include "Math/gf2n.h"
|
||||
#include "FHE/P2Data.h"
|
||||
#include "Tools/ezOptionParser.h"
|
||||
|
||||
#include "Player-Online.hpp"
|
||||
#include "Protocols/HemiPrep.hpp"
|
||||
#include "Processor/Data_Files.hpp"
|
||||
#include "Processor/Instruction.hpp"
|
||||
#include "Processor/Machine.hpp"
|
||||
#include "Protocols/SemiPrep.hpp"
|
||||
#include "Protocols/SemiInput.hpp"
|
||||
#include "Protocols/MAC_Check_Base.hpp"
|
||||
#include "Protocols/MAC_Check.hpp"
|
||||
#include "Protocols/fake-stuff.hpp"
|
||||
#include "Protocols/SemiMC.hpp"
|
||||
#include "Protocols/Beaver.hpp"
|
||||
|
||||
int main(int argc, const char** argv)
|
||||
{
|
||||
ez::ezOptionParser opt;
|
||||
spdz_main<HemiShare<gfp>, HemiShare<gf2n_short>>(argc, argv, opt);
|
||||
}
|
||||
3
Makefile
3
Makefile
@@ -38,7 +38,7 @@ DEPS := $(wildcard */*.d)
|
||||
all: gen_input online offline externalIO bmr yao replicated shamir real-bmr spdz2k-party.x brain-party.x semi-party.x semi2k-party.x semi-bin-party.x mascot-party.x tiny-party.x
|
||||
|
||||
ifeq ($(USE_NTL),1)
|
||||
all: overdrive she-offline cowgear-party.x
|
||||
all: overdrive she-offline cowgear-party.x hemi-party.x
|
||||
endif
|
||||
|
||||
-include $(DEPS)
|
||||
@@ -165,6 +165,7 @@ malicious-shamir-party.x: Machines/ShamirMachine.o
|
||||
spdz2k-party.x: $(OT)
|
||||
semi-party.x: $(OT)
|
||||
semi2k-party.x: $(OT)
|
||||
hemi-party.x: $(FHEOFFLINE)
|
||||
cowgear-party.x: $(FHEOFFLINE) Protocols/CowGearOptions.o
|
||||
mascot-party.x: Machines/SPDZ.o $(OT)
|
||||
Player-Online.x: Machines/SPDZ.o $(OT)
|
||||
|
||||
@@ -110,6 +110,13 @@ void generate_online_setup(ofstream& outf, string dirname, bigint& p, int lgp, i
|
||||
}
|
||||
|
||||
void write_online_setup(ofstream& outf, string dirname, const bigint& p, int lg2, bool mont)
|
||||
{
|
||||
write_online_setup_without_init(outf, dirname, p, lg2);
|
||||
gfp::init_field(p, mont);
|
||||
init_gf2n(lg2);
|
||||
}
|
||||
|
||||
void write_online_setup_without_init(ofstream& outf, string dirname, const bigint& p, int lg2)
|
||||
{
|
||||
if (p == 0)
|
||||
throw runtime_error("prime cannot be 0");
|
||||
@@ -132,9 +139,6 @@ void write_online_setup(ofstream& outf, string dirname, const bigint& p, int lg2
|
||||
// Fix as a negative lg2 is a ``signal'' to choose slightly weaker
|
||||
// LWE parameters
|
||||
outf << abs(lg2) << endl;
|
||||
|
||||
gfp::init_field(p, mont);
|
||||
init_gf2n(lg2);
|
||||
}
|
||||
|
||||
void init_gf2n(int lg2)
|
||||
|
||||
@@ -22,6 +22,7 @@ using namespace std;
|
||||
// Create setup file for gfp and gf2n
|
||||
void generate_online_setup(ofstream& outf, string dirname, bigint& p, int lgp, int lg2);
|
||||
void write_online_setup(ofstream& outf, string dirname, const bigint& p, int lg2, bool mont = true);
|
||||
void write_online_setup_without_init(ofstream& outf, string dirname, const bigint& p, int lg2);
|
||||
|
||||
// Setup primes only
|
||||
// Chooses a p of at least lgp bits
|
||||
|
||||
@@ -15,14 +15,14 @@ void Square<gf2n_short>::to(gf2n_short& result)
|
||||
result = sum;
|
||||
}
|
||||
|
||||
template <>
|
||||
void Square<gfp1>::to(gfp1& result)
|
||||
template<class U>
|
||||
template<int X, int L>
|
||||
void Square<U>::to(gfp_<X, L>& result)
|
||||
{
|
||||
const int L = gfp1::N_LIMBS;
|
||||
mp_limb_t product[2 * L], sum[2 * L], tmp[L][2 * L];
|
||||
memset(tmp, 0, sizeof(tmp));
|
||||
memset(sum, 0, sizeof(sum));
|
||||
for (int i = 0; i < gfp1::length(); i++)
|
||||
for (int i = 0; i < gfp_<X, L>::length(); i++)
|
||||
{
|
||||
memcpy(&(tmp[i/64][i/64]), &(rows[i]), sizeof(rows[i]));
|
||||
if (i % 64 == 0)
|
||||
@@ -32,10 +32,22 @@ void Square<gfp1>::to(gfp1& result)
|
||||
mpn_add_fixed_n<2 * L>(sum, product, sum);
|
||||
}
|
||||
mp_limb_t q[2 * L], ans[2 * L];
|
||||
mpn_tdiv_qr(q, ans, 0, sum, 2 * L, gfp1::get_ZpD().get_prA(), L);
|
||||
mpn_tdiv_qr(q, ans, 0, sum, 2 * L, gfp_<X, L>::get_ZpD().get_prA(), L);
|
||||
result.assign((void*) ans);
|
||||
}
|
||||
|
||||
template<>
|
||||
void Square<gfp1>::to(gfp1& result)
|
||||
{
|
||||
to<1, GFP_MOD_SZ>(result);
|
||||
}
|
||||
|
||||
template<>
|
||||
void Square<gfp3>::to(gfp3& result)
|
||||
{
|
||||
to<3, 4>(result);
|
||||
}
|
||||
|
||||
template<>
|
||||
void Square<BitVec>::to(BitVec& result)
|
||||
{
|
||||
|
||||
@@ -31,6 +31,8 @@ public:
|
||||
void conditional_add(BitVector& conditions, Square& other,
|
||||
int offset);
|
||||
void to(U& result);
|
||||
template<int X, int L>
|
||||
void to(gfp_<X, L>& result);
|
||||
|
||||
void pack(octetStream& os) const;
|
||||
void unpack(octetStream& os);
|
||||
|
||||
@@ -283,7 +283,8 @@ Z2<K> Z2<K>::operator>>(int i) const
|
||||
{
|
||||
Z2<K> res;
|
||||
int n_byte_shift = i / 8;
|
||||
memcpy(res.a, (char*)a + n_byte_shift, N_BYTES - n_byte_shift);
|
||||
if (N_BYTES - n_byte_shift > 0)
|
||||
memcpy(res.a, (char*)a + n_byte_shift, N_BYTES - n_byte_shift);
|
||||
int n_inside_shift = i % 8;
|
||||
if (n_inside_shift > 0)
|
||||
mpn_rshift(res.a, res.a, N_WORDS, n_inside_shift);
|
||||
|
||||
13
Math/Z2k.hpp
13
Math/Z2k.hpp
@@ -147,4 +147,17 @@ ostream& operator<<(ostream& o, const Z2<K>& x)
|
||||
return o;
|
||||
}
|
||||
|
||||
template<int K>
|
||||
istream& operator>>(istream& i, SignedZ2<K>& x)
|
||||
{
|
||||
auto& tmp = bigint::tmp;
|
||||
i >> tmp;
|
||||
if (tmp.numBits() > K + 1)
|
||||
throw runtime_error(
|
||||
tmp.get_str() + " out of range for signed " + to_string(K)
|
||||
+ "-bit numbers");
|
||||
x = tmp;
|
||||
return i;
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -262,6 +262,8 @@ typedef gfp_<0, GFP_MOD_SZ> gfp;
|
||||
typedef gfp_<1, GFP_MOD_SZ> gfp1;
|
||||
// enough for Brain protocol with 64-bit computation and 40-bit security
|
||||
typedef gfp_<2, 4> gfp2;
|
||||
// for OT-based ECDSA
|
||||
typedef gfp_<3, 4> gfp3;
|
||||
|
||||
void to_signed_bigint(bigint& ans,const gfp& x);
|
||||
|
||||
|
||||
@@ -110,7 +110,7 @@ inline mp_limb_t mpn_add_n_with_carry(mp_limb_t* res, const mp_limb_t* x, const
|
||||
char carry = 0;
|
||||
for (int i = 0; i < n; i++)
|
||||
#if defined(__clang__)
|
||||
#if __clang_major__ < 8 || defined(__APPLE__)
|
||||
#if __clang_major__ < 8 || (defined(__APPLE__) && __clang_major__ < 11)
|
||||
carry = __builtin_ia32_addcarry_u64(carry, x[i], y[i], (unsigned long long*)&res[i]);
|
||||
#else
|
||||
carry = __builtin_ia32_addcarryx_u64(carry, x[i], y[i], (unsigned long long*)&res[i]);
|
||||
|
||||
@@ -97,12 +97,20 @@ Names::Names(ez::ezOptionParser& opt, int argc, const char** argv,
|
||||
1, // Required?
|
||||
1, // Number of args expected.
|
||||
0, // Delimiter if expecting multiple args.
|
||||
"This player's number", // Help description.
|
||||
"This player's number (required)", // Help description.
|
||||
"-p", // Flag token.
|
||||
"--player" // Flag token.
|
||||
);
|
||||
opt.parse(argc, argv);
|
||||
opt.get("-p")->getInt(player_no);
|
||||
vector<string> missing;
|
||||
if (not opt.gotRequired(missing))
|
||||
{
|
||||
string usage;
|
||||
opt.getUsage(usage);
|
||||
cerr << usage;
|
||||
exit(1);
|
||||
}
|
||||
global_server = network_opts.start_networking(*this, player_no);
|
||||
}
|
||||
|
||||
@@ -123,7 +131,7 @@ void Names::setup_names(const char *servername, int my_port)
|
||||
set_up_client_socket(socket_num, servername, pn);
|
||||
send(socket_num, (octet*)&player_no, sizeof(player_no));
|
||||
#ifdef DEBUG_NETWORKING
|
||||
fprintf(stderr, "Sent %d to %s:%d\n", player_no, servername, pn);
|
||||
cerr << "Sent " << player_no << " to " << servername << ":" << pn << endl;
|
||||
#endif
|
||||
|
||||
int inst=-1; // wait until instruction to start.
|
||||
@@ -338,6 +346,14 @@ void MultiPlayer<T>::send_all(const octetStream& o,bool donthash) const
|
||||
}
|
||||
|
||||
|
||||
void Player::receive_all(vector<octetStream>& os) const
|
||||
{
|
||||
for (int j = 0; j < num_players(); j++)
|
||||
if (j != my_num())
|
||||
receive_player(j, os[j], true);
|
||||
}
|
||||
|
||||
|
||||
void Player::receive_player(int i,octetStream& o,bool donthash) const
|
||||
{
|
||||
#ifdef VERBOSE_COMM
|
||||
@@ -345,6 +361,7 @@ void Player::receive_player(int i,octetStream& o,bool donthash) const
|
||||
#endif
|
||||
TimeScope ts(timer);
|
||||
receive_player_no_stats(i, o);
|
||||
comm_stats["Receiving directly"].add(o, ts);
|
||||
if (!donthash)
|
||||
{ blk_SHA1_Update(&ctx,o.get_data(),o.get_length()); }
|
||||
}
|
||||
@@ -627,12 +644,14 @@ void RealTwoPartyPlayer::receive(octetStream& o) const
|
||||
TimeScope ts(timer);
|
||||
o.reset_write_head();
|
||||
o.Receive(socket);
|
||||
comm_stats["Receiving one-to-one"].add(o, ts);
|
||||
}
|
||||
|
||||
void VirtualTwoPartyPlayer::receive(octetStream& o) const
|
||||
{
|
||||
TimeScope ts(timer);
|
||||
P.receive_player_no_stats(other_player, o);
|
||||
comm_stats["Receiving one-to-one"].add(o, ts);
|
||||
}
|
||||
|
||||
void RealTwoPartyPlayer::send_receive_player(vector<octetStream>& o) const
|
||||
@@ -688,6 +707,8 @@ void TwoPartyPlayer::Broadcast_Receive(vector<octetStream>& o,
|
||||
CommStats& CommStats::operator +=(const CommStats& other)
|
||||
{
|
||||
data += other.data;
|
||||
rounds += other.rounds;
|
||||
timer += other.timer;
|
||||
return *this;
|
||||
}
|
||||
|
||||
@@ -698,6 +719,13 @@ NamedCommStats& NamedCommStats::operator +=(const NamedCommStats& other)
|
||||
return *this;
|
||||
}
|
||||
|
||||
NamedCommStats NamedCommStats::operator +(const NamedCommStats& other) const
|
||||
{
|
||||
auto res = *this;
|
||||
res += other;
|
||||
return res;
|
||||
}
|
||||
|
||||
CommStats& CommStats::operator -=(const CommStats& other)
|
||||
{
|
||||
data -= other.data;
|
||||
@@ -722,13 +750,15 @@ size_t NamedCommStats::total_data()
|
||||
return res;
|
||||
}
|
||||
|
||||
void NamedCommStats::print()
|
||||
void NamedCommStats::print(bool newline)
|
||||
{
|
||||
for (auto it = begin(); it != end(); it++)
|
||||
if (it->second.data)
|
||||
cerr << it->first << " " << 1e-6 * it->second.data << " MB in "
|
||||
<< it->second.rounds << " rounds, taking " << it->second.timer.elapsed()
|
||||
<< " seconds" << endl;
|
||||
if (size() and newline)
|
||||
cerr << endl;
|
||||
}
|
||||
|
||||
template class MultiPlayer<int>;
|
||||
|
||||
@@ -91,6 +91,7 @@ struct CommStats
|
||||
Timer timer;
|
||||
CommStats() : data(0), rounds(0) {}
|
||||
Timer& add(const octetStream& os) { data += os.get_length(); rounds++; return timer; }
|
||||
void add(const octetStream& os, const TimeScope& scope) { add(os) += scope; }
|
||||
CommStats& operator+=(const CommStats& other);
|
||||
CommStats& operator-=(const CommStats& other);
|
||||
};
|
||||
@@ -99,9 +100,10 @@ class NamedCommStats : public map<string, CommStats>
|
||||
{
|
||||
public:
|
||||
NamedCommStats& operator+=(const NamedCommStats& other);
|
||||
NamedCommStats operator+(const NamedCommStats& other) const;
|
||||
NamedCommStats operator-(const NamedCommStats& other) const;
|
||||
size_t total_data();
|
||||
void print();
|
||||
void print(bool newline = false);
|
||||
#ifdef VERBOSE_COMM
|
||||
CommStats& operator[](const string& name)
|
||||
{
|
||||
@@ -160,6 +162,7 @@ public:
|
||||
virtual void send_all(const octetStream& o,bool donthash=false) const = 0;
|
||||
void send_to(int player,const octetStream& o,bool donthash=false) const;
|
||||
virtual void send_to_no_stats(int player,const octetStream& o) const = 0;
|
||||
void receive_all(vector<octetStream>& os) const;
|
||||
void receive_player(int i,octetStream& o,bool donthash=false) const;
|
||||
virtual void receive_player_no_stats(int i,octetStream& o) const = 0;
|
||||
virtual void receive_player(int i,FlexBuffer& buffer) const;
|
||||
|
||||
@@ -83,84 +83,40 @@ ServerSocket::~ServerSocket()
|
||||
|
||||
void ServerSocket::accept_clients()
|
||||
{
|
||||
map<int, sockaddr> unassigned_sockets;
|
||||
|
||||
while (true)
|
||||
{
|
||||
fd_set readfds;
|
||||
FD_ZERO(&readfds);
|
||||
int nfds = main_socket;
|
||||
FD_SET(main_socket, &readfds);
|
||||
for (auto &socket : unassigned_sockets)
|
||||
{
|
||||
FD_SET(socket.first, &readfds);
|
||||
nfds = max(socket.first, nfds);
|
||||
}
|
||||
struct sockaddr dest;
|
||||
memset(&dest, 0, sizeof(dest)); /* zero the struct before filling the fields */
|
||||
int socksize = sizeof(dest);
|
||||
int consocket = accept(main_socket, (struct sockaddr *)&dest, (socklen_t*) &socksize);
|
||||
if (consocket<0) { error("set_up_socket:accept"); }
|
||||
|
||||
select(nfds + 1, &readfds, 0, 0, 0);
|
||||
|
||||
if (FD_ISSET(main_socket, &readfds))
|
||||
{
|
||||
struct sockaddr dest;
|
||||
memset(&dest, 0, sizeof(dest)); /* zero the struct before filling the fields */
|
||||
int socksize = sizeof(dest);
|
||||
int consocket = accept(main_socket, (struct sockaddr*) &dest,
|
||||
(socklen_t*) &socksize);
|
||||
if (consocket < 0)
|
||||
error("set_up_socket:accept");
|
||||
unassigned_sockets[consocket] = dest;
|
||||
int client_id;
|
||||
try
|
||||
{
|
||||
receive(consocket, (unsigned char*)&client_id, sizeof(client_id));
|
||||
}
|
||||
catch (closed_connection&)
|
||||
{
|
||||
#ifdef DEBUG_NETWORKING
|
||||
auto &conn = *(sockaddr_in*) &dest;
|
||||
fprintf(stderr, "new client on %s:%d\n", inet_ntoa(conn.sin_addr),
|
||||
ntohs(conn.sin_port));
|
||||
auto& conn = *(sockaddr_in*)&dest;
|
||||
cerr << "client on " << inet_ntoa(conn.sin_addr) << ":"
|
||||
<< ntohs(conn.sin_port) << " left without identification"
|
||||
<< endl;
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
vector<int> processed_sockets;
|
||||
for (auto &socket : unassigned_sockets)
|
||||
{
|
||||
int consocket = socket.first;
|
||||
if (FD_ISSET(consocket, &readfds))
|
||||
{
|
||||
try
|
||||
{
|
||||
int client_id;
|
||||
receive(consocket, (unsigned char*) &client_id,
|
||||
sizeof(client_id));
|
||||
|
||||
data_signal.lock();
|
||||
clients[client_id] = consocket;
|
||||
data_signal.broadcast();
|
||||
data_signal.unlock();
|
||||
|
||||
#ifdef DEBUG_NETWORKING
|
||||
auto &conn = *(sockaddr_in*) &socket.second;
|
||||
fprintf(stderr, "client id %d on %s:%d\n", client_id,
|
||||
inet_ntoa(conn.sin_addr), ntohs(conn.sin_port));
|
||||
#endif
|
||||
data_signal.lock();
|
||||
clients[client_id] = consocket;
|
||||
data_signal.broadcast();
|
||||
data_signal.unlock();
|
||||
|
||||
#ifdef __APPLE__
|
||||
int flags = fcntl(consocket, F_GETFL, 0);
|
||||
int fl = fcntl(consocket, F_SETFL, O_NONBLOCK | flags);
|
||||
if (fl < 0)
|
||||
error("set non-blocking");
|
||||
int flags = fcntl(consocket, F_GETFL, 0);
|
||||
int fl = fcntl(consocket, F_SETFL, O_NONBLOCK | flags);
|
||||
if (fl < 0)
|
||||
error("set non-blocking");
|
||||
#endif
|
||||
}
|
||||
catch (closed_connection&)
|
||||
{
|
||||
#ifdef DEBUG_NETWORKING
|
||||
auto &conn = *(sockaddr_in*) &socket.second;
|
||||
cerr << "client on " << inet_ntoa(conn.sin_addr) << ":"
|
||||
<< ntohs(conn.sin_port) << " left without identification"
|
||||
<< endl;
|
||||
#endif
|
||||
close_client_socket(consocket);
|
||||
}
|
||||
processed_sockets.push_back(consocket);
|
||||
}
|
||||
}
|
||||
for (int socket : processed_sockets)
|
||||
unassigned_sockets.erase(socket);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -116,7 +116,7 @@ void BaseOT::exec_base(bool new_receiver_inputs)
|
||||
{
|
||||
if (new_receiver_inputs)
|
||||
receiver_inputs[i + j] = G.get_uchar()&1;
|
||||
cs[j] = receiver_inputs[i + j];
|
||||
cs[j] = receiver_inputs[i + j].get();
|
||||
}
|
||||
receiver_rsgen(&receiver, Rs_pack[0], cs);
|
||||
os[0].store_bytes(Rs_pack[0], sizeof(Rs_pack[0]));
|
||||
@@ -293,7 +293,7 @@ void FakeOT::exec_base(bool new_receiver_inputs)
|
||||
{
|
||||
for (int j = 0; j < 2; j++)
|
||||
bv[j].unpack(os[1]);
|
||||
receiver_outputs[i] = bv[receiver_inputs[i]];
|
||||
receiver_outputs[i] = bv[receiver_inputs[i].get()];
|
||||
}
|
||||
|
||||
set_seeds();
|
||||
|
||||
@@ -30,7 +30,7 @@ void send_if_ot_receiver(TwoPartyPlayer* P, vector<octetStream>& os, OT_ROLE rol
|
||||
class BaseOT
|
||||
{
|
||||
public:
|
||||
vector<int> receiver_inputs;
|
||||
BitVector receiver_inputs;
|
||||
vector< vector<BitVector> > sender_inputs;
|
||||
vector<BitVector> receiver_outputs;
|
||||
TwoPartyPlayer* P;
|
||||
@@ -63,7 +63,7 @@ public:
|
||||
|
||||
int length() { return ot_length; }
|
||||
|
||||
void set_receiver_inputs(const vector<int>& new_inputs)
|
||||
void set_receiver_inputs(const BitVector& new_inputs)
|
||||
{
|
||||
if ((int)new_inputs.size() != nOT)
|
||||
throw invalid_length();
|
||||
@@ -72,7 +72,7 @@ public:
|
||||
|
||||
void set_receiver_inputs(int128 inputs)
|
||||
{
|
||||
vector<int> new_inputs(128);
|
||||
BitVector new_inputs(128);
|
||||
for (int i = 0; i < 128; i++)
|
||||
new_inputs[i] = (inputs >> i).get_lower() & 1;
|
||||
set_receiver_inputs(new_inputs);
|
||||
@@ -81,6 +81,7 @@ public:
|
||||
// do the OTs -- generate fresh random choice bits by default
|
||||
virtual void exec_base(bool new_receiver_inputs=true);
|
||||
// use PRG to get the next ot_length bits
|
||||
void set_seeds();
|
||||
void extend_length();
|
||||
void check();
|
||||
|
||||
@@ -90,8 +91,6 @@ protected:
|
||||
|
||||
bool is_sender() { return (bool) (ot_role & SENDER); }
|
||||
bool is_receiver() { return (bool) (ot_role & RECEIVER); }
|
||||
|
||||
void set_seeds();
|
||||
};
|
||||
|
||||
class FakeOT : public BaseOT
|
||||
|
||||
@@ -702,4 +702,5 @@ BMS
|
||||
XXXX(Matrix<gf2n_short_square>, gf2n_short)
|
||||
XXXX(Matrix<Square<gf2n_long>>, gf2n_long)
|
||||
XXXX(Matrix<Square<gfp1>>, gfp1)
|
||||
XXXX(Matrix<Square<gfp3>>, gfp3)
|
||||
XXXX(Matrix<BitDiagonal>, BitVec)
|
||||
|
||||
@@ -142,7 +142,7 @@ public:
|
||||
BitMatrix() {}
|
||||
BitMatrix(int length);
|
||||
|
||||
__m128i operator[](int i) { return squares[i / 128].rows[i % 128]; }
|
||||
__m128i& operator[](int i) { return squares[i / 128].rows[i % 128]; }
|
||||
|
||||
void resize(int length);
|
||||
int size();
|
||||
|
||||
@@ -26,81 +26,15 @@ MascotParams::MascotParams()
|
||||
generateMACs = true;
|
||||
amplify = true;
|
||||
check = true;
|
||||
correlation_check = true;
|
||||
generateBits = false;
|
||||
use_extension = true;
|
||||
fewer_rounds = false;
|
||||
fiat_shamir = false;
|
||||
timerclear(&start);
|
||||
}
|
||||
|
||||
void MascotParams::set_passive()
|
||||
{
|
||||
generateMACs = amplify = check = false;
|
||||
}
|
||||
|
||||
template<> gf2n_long MascotParams::get_mac_key()
|
||||
{
|
||||
return mac_key2l;
|
||||
}
|
||||
|
||||
template<> gf2n_short MascotParams::get_mac_key()
|
||||
{
|
||||
return mac_key2s;
|
||||
}
|
||||
|
||||
template<> gfp1 MascotParams::get_mac_key()
|
||||
{
|
||||
return mac_keyp;
|
||||
}
|
||||
|
||||
template<> Z2<48> MascotParams::get_mac_key()
|
||||
{
|
||||
return mac_keyz;
|
||||
}
|
||||
|
||||
template<> Z2<64> MascotParams::get_mac_key()
|
||||
{
|
||||
return mac_keyz;
|
||||
}
|
||||
|
||||
template<> Z2<40> MascotParams::get_mac_key()
|
||||
{
|
||||
return mac_keyz;
|
||||
}
|
||||
|
||||
template<> Z2<32> MascotParams::get_mac_key()
|
||||
{
|
||||
return mac_keyz;
|
||||
}
|
||||
|
||||
template<> BitVec MascotParams::get_mac_key()
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
template<> void MascotParams::set_mac_key(gf2n_long key)
|
||||
{
|
||||
mac_key2l = key;
|
||||
}
|
||||
|
||||
template<> void MascotParams::set_mac_key(gf2n_short key)
|
||||
{
|
||||
mac_key2s = key;
|
||||
}
|
||||
|
||||
template<> void MascotParams::set_mac_key(gfp1 key)
|
||||
{
|
||||
mac_keyp = key;
|
||||
}
|
||||
|
||||
template<> void MascotParams::set_mac_key(Z2<64> key)
|
||||
{
|
||||
mac_keyz = key;
|
||||
}
|
||||
|
||||
template<> void MascotParams::set_mac_key(Z2<48> key)
|
||||
{
|
||||
mac_keyz = key;
|
||||
}
|
||||
|
||||
template<> void MascotParams::set_mac_key(Z2<40> key)
|
||||
{
|
||||
mac_keyz = key;
|
||||
generateMACs = amplify = check = correlation_check = false;
|
||||
}
|
||||
|
||||
@@ -68,6 +68,8 @@ protected:
|
||||
|
||||
SeededPRNG share_prg;
|
||||
|
||||
mac_key_type mac_key;
|
||||
|
||||
void start_progress();
|
||||
void print_progress(int k);
|
||||
|
||||
@@ -101,8 +103,11 @@ public:
|
||||
vector<PlainTriple<open_type, N_AMPLIFY>> preampTriples;
|
||||
vector<array<open_type, 3>> plainTriples;
|
||||
|
||||
OTTripleGenerator(OTTripleSetup& setup, const Names& names,
|
||||
typename T::MAC_Check* MC;
|
||||
|
||||
OTTripleGenerator(const OTTripleSetup& setup, const Names& names,
|
||||
int thread_num, int nTriples, int nloops, MascotParams& machine,
|
||||
mac_key_type mac_key,
|
||||
Player* parentPlayer = 0);
|
||||
~OTTripleGenerator();
|
||||
|
||||
@@ -113,7 +118,10 @@ public:
|
||||
|
||||
void run_multipliers(MultJob job);
|
||||
|
||||
mac_key_type get_mac_key() const { return mac_key; }
|
||||
|
||||
size_t data_sent();
|
||||
NamedCommStats comm_stats();
|
||||
};
|
||||
|
||||
template<class T>
|
||||
@@ -130,8 +138,9 @@ public:
|
||||
vector< ShareTriple_<sacri_type, mac_key_type, 2> > uncheckedTriples;
|
||||
vector<InputTuple<Share<sacri_type>>> inputs;
|
||||
|
||||
NPartyTripleGenerator(OTTripleSetup& setup, const Names& names,
|
||||
NPartyTripleGenerator(const OTTripleSetup& setup, const Names& names,
|
||||
int thread_num, int nTriples, int nloops, MascotParams& machine,
|
||||
mac_key_type mac_key,
|
||||
Player* parentPlayer = 0);
|
||||
virtual ~NPartyTripleGenerator() {}
|
||||
|
||||
@@ -159,8 +168,9 @@ class MascotTripleGenerator : public NPartyTripleGenerator<T>
|
||||
public:
|
||||
vector<T> bits;
|
||||
|
||||
MascotTripleGenerator(OTTripleSetup& setup, const Names& names,
|
||||
MascotTripleGenerator(const OTTripleSetup& setup, const Names& names,
|
||||
int thread_num, int nTriples, int nloops, MascotParams& machine,
|
||||
mac_key_type mac_key,
|
||||
Player* parentPlayer = 0);
|
||||
};
|
||||
|
||||
@@ -181,8 +191,9 @@ class Spdz2kTripleGenerator : public NPartyTripleGenerator<T>
|
||||
U& MC, PRNG& G);
|
||||
|
||||
public:
|
||||
Spdz2kTripleGenerator(OTTripleSetup& setup, const Names& names,
|
||||
Spdz2kTripleGenerator(const OTTripleSetup& setup, const Names& names,
|
||||
int thread_num, int nTriples, int nloops, MascotParams& machine,
|
||||
mac_key_type mac_key,
|
||||
Player* parentPlayer = 0);
|
||||
|
||||
void generateTriples();
|
||||
@@ -199,4 +210,15 @@ size_t OTTripleGenerator<T>::data_sent()
|
||||
return res;
|
||||
}
|
||||
|
||||
template<class T>
|
||||
NamedCommStats OTTripleGenerator<T>::comm_stats()
|
||||
{
|
||||
NamedCommStats res;
|
||||
if (parentPlayer != &globalPlayer)
|
||||
res = globalPlayer.comm_stats;
|
||||
for (auto& player : players)
|
||||
res += player->comm_stats;
|
||||
return res;
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -5,27 +5,14 @@
|
||||
|
||||
#include "OT/OTExtensionWithMatrix.h"
|
||||
#include "OT/OTMultiplier.h"
|
||||
#include "Math/gfp.h"
|
||||
#include "Protocols/Share.h"
|
||||
#include "Protocols/SemiShare.h"
|
||||
#include "Protocols/Semi2kShare.h"
|
||||
#include "Protocols/Spdz2kShare.h"
|
||||
#include "Math/operators.h"
|
||||
#include "Tools/Subroutines.h"
|
||||
#include "Protocols/MAC_Check.h"
|
||||
#include "Protocols/Spdz2kPrep.h"
|
||||
#include "GC/SemiSecret.h"
|
||||
|
||||
#include "OT/Triple.hpp"
|
||||
#include "OT/Rectangle.hpp"
|
||||
#include "OT/OTMultiplier.hpp"
|
||||
#include "Protocols/MAC_Check.hpp"
|
||||
#include "Protocols/SemiMC.h"
|
||||
#include "Protocols/MascotPrep.hpp"
|
||||
#include "Protocols/ReplicatedInput.hpp"
|
||||
#include "Protocols/SemiInput.hpp"
|
||||
#include "Processor/Input.hpp"
|
||||
#include "Math/Z2k.hpp"
|
||||
|
||||
#include <sstream>
|
||||
#include <fstream>
|
||||
@@ -43,44 +30,46 @@ void* run_ot_thread(void* ptr)
|
||||
* N.B. setup must not be stored as it will be used by other threads
|
||||
*/
|
||||
template<class T>
|
||||
NPartyTripleGenerator<T>::NPartyTripleGenerator(OTTripleSetup& setup,
|
||||
NPartyTripleGenerator<T>::NPartyTripleGenerator(const OTTripleSetup& setup,
|
||||
const Names& names, int thread_num, int _nTriples, int nloops,
|
||||
MascotParams& machine, Player* parentPlayer) :
|
||||
MascotParams& machine, mac_key_type mac_key, Player* parentPlayer) :
|
||||
OTTripleGenerator<T>(setup, names, thread_num, _nTriples, nloops,
|
||||
machine, parentPlayer)
|
||||
machine, mac_key, parentPlayer)
|
||||
{
|
||||
}
|
||||
|
||||
template<class T>
|
||||
MascotTripleGenerator<T>::MascotTripleGenerator(OTTripleSetup& setup,
|
||||
MascotTripleGenerator<T>::MascotTripleGenerator(const OTTripleSetup& setup,
|
||||
const Names& names, int thread_num, int _nTriples, int nloops,
|
||||
MascotParams& machine, Player* parentPlayer) :
|
||||
MascotParams& machine, mac_key_type mac_key, Player* parentPlayer) :
|
||||
NPartyTripleGenerator<T>(setup, names, thread_num, _nTriples, nloops,
|
||||
machine, parentPlayer)
|
||||
machine, mac_key, parentPlayer)
|
||||
{
|
||||
}
|
||||
|
||||
template<class T>
|
||||
Spdz2kTripleGenerator<T>::Spdz2kTripleGenerator(OTTripleSetup& setup,
|
||||
Spdz2kTripleGenerator<T>::Spdz2kTripleGenerator(const OTTripleSetup& setup,
|
||||
const Names& names, int thread_num, int _nTriples, int nloops,
|
||||
MascotParams& machine, Player* parentPlayer) :
|
||||
MascotParams& machine, mac_key_type mac_key, Player* parentPlayer) :
|
||||
NPartyTripleGenerator<T>(setup, names, thread_num, _nTriples, nloops,
|
||||
machine, parentPlayer)
|
||||
machine, mac_key, parentPlayer)
|
||||
{
|
||||
}
|
||||
|
||||
template<class T>
|
||||
OTTripleGenerator<T>::OTTripleGenerator(OTTripleSetup& setup,
|
||||
OTTripleGenerator<T>::OTTripleGenerator(const OTTripleSetup& setup,
|
||||
const Names& names, int thread_num, int _nTriples, int nloops,
|
||||
MascotParams& machine, Player* parentPlayer) :
|
||||
MascotParams& machine, mac_key_type mac_key, Player* parentPlayer) :
|
||||
globalPlayer(parentPlayer ? *parentPlayer : *new PlainPlayer(names,
|
||||
- thread_num * names.num_players() * names.num_players())),
|
||||
parentPlayer(parentPlayer),
|
||||
thread_num(thread_num),
|
||||
mac_key(mac_key),
|
||||
my_num(setup.get_my_num()),
|
||||
nloops(nloops),
|
||||
nparties(setup.get_nparties()),
|
||||
machine(machine)
|
||||
machine(machine),
|
||||
MC(0)
|
||||
{
|
||||
nTriplesPerLoop = DIV_CEIL(_nTriples, nloops);
|
||||
nTriples = nTriplesPerLoop * nloops;
|
||||
@@ -208,7 +197,6 @@ void NPartyTripleGenerator<W>::generateInputs(int player)
|
||||
{
|
||||
typedef open_type T;
|
||||
|
||||
auto& machine = this->machine;
|
||||
auto& nTriplesPerLoop = this->nTriplesPerLoop;
|
||||
auto& valueBits = this->valueBits;
|
||||
auto& share_prg = this->share_prg;
|
||||
@@ -235,7 +223,7 @@ void NPartyTripleGenerator<W>::generateInputs(int player)
|
||||
GlobalPRNG G(globalPlayer);
|
||||
Share<T> check_sum;
|
||||
inputs.resize(toCheck);
|
||||
auto mac_key = machine.template get_mac_key<mac_key_type>();
|
||||
auto mac_key = this->get_mac_key();
|
||||
SemiInput<SemiShare<T>> input(0, globalPlayer);
|
||||
input.reset_all(globalPlayer);
|
||||
vector<T> secrets(toCheck);
|
||||
@@ -289,7 +277,7 @@ void MascotTripleGenerator<T>::generateBitsGf2n()
|
||||
bits.resize(nBitsToCheck);
|
||||
vector<T> to_open(1);
|
||||
vector<typename T::clear> opened(1);
|
||||
MAC_Check_<T> MC(this->machine.template get_mac_key<typename T::clear>());
|
||||
MAC_Check_<T> MC(this->get_mac_key());
|
||||
|
||||
this->start_progress();
|
||||
|
||||
@@ -313,7 +301,7 @@ void MascotTripleGenerator<T>::generateBitsGf2n()
|
||||
typename T::clear r;
|
||||
for (int j = 0; j < nBitsToCheck; j++)
|
||||
{
|
||||
auto mac_sum = valueBits[0].get_bit(j) ? MC.get_alphai() : 0;
|
||||
auto mac_sum = valueBits[0].get_bit(j) ? this->get_mac_key() : 0;
|
||||
for (int i = 0; i < this->nparties-1; i++)
|
||||
mac_sum += this->ot_multipliers[i]->macs[0][j];
|
||||
bits[j].set_share(valueBits[0].get_bit(j));
|
||||
@@ -352,6 +340,13 @@ void MascotTripleGenerator<Share<gfp1>>::generateBits()
|
||||
generateTriples();
|
||||
}
|
||||
|
||||
template<>
|
||||
inline
|
||||
void MascotTripleGenerator<Share<gfp3>>::generateBits()
|
||||
{
|
||||
generateTriples();
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void Spdz2kTripleGenerator<T>::generateTriples()
|
||||
{
|
||||
@@ -360,7 +355,6 @@ void Spdz2kTripleGenerator<T>::generateTriples()
|
||||
auto& uncheckedTriples = this->uncheckedTriples;
|
||||
|
||||
auto& timers = this->timers;
|
||||
auto& machine = this->machine;
|
||||
auto& nTriplesPerLoop = this->nTriplesPerLoop;
|
||||
auto& valueBits = this->valueBits;
|
||||
auto& share_prg = this->share_prg;
|
||||
@@ -382,7 +376,7 @@ void Spdz2kTripleGenerator<T>::generateTriples()
|
||||
vector< PlainTriple_<Z2<K + 2 * S>, Z2<K + S>, 2> > amplifiedTriples(nTriplesPerLoop);
|
||||
uncheckedTriples.resize(nTriplesPerLoop);
|
||||
MAC_Check_Z2k<Z2<K + 2 * S>, Z2<S>, Z2<K + S>, Share<Z2<K + 2 * S>> > MC(
|
||||
machine.template get_mac_key<Z2<S> >());
|
||||
this->get_mac_key());
|
||||
|
||||
this->start_progress();
|
||||
|
||||
@@ -455,7 +449,7 @@ void Spdz2kTripleGenerator<T>::generateTriples()
|
||||
// get piggy-backed random value
|
||||
Z2<K + 2 * S> r_share = b_padded_bits.get_ptr_to_byte(nTriplesPerLoop, Z2<K + 2 * S>::N_BYTES);
|
||||
Z2<K + 2 * S> r_mac;
|
||||
r_mac.mul(r_share, this->machine.template get_mac_key<Z2<S>>());
|
||||
r_mac.mul(r_share, this->get_mac_key());
|
||||
for (int i = 0; i < this->nparties-1; i++)
|
||||
r_mac += (ot_multipliers[i])->macs.at(1).at(nTriplesPerLoop);
|
||||
Share<Z2<K + 2 * S>> r;
|
||||
@@ -563,16 +557,17 @@ void MascotTripleGenerator<U>::generateTriples()
|
||||
valueBits[2*i].resize(field_size * nPreampTriplesPerLoop);
|
||||
valueBits[1].resize(field_size * nTriplesPerLoop);
|
||||
vector< PlainTriple<T,2> > amplifiedTriples;
|
||||
MAC_Check<T> MC(machine.template get_mac_key<T>());
|
||||
MAC_Check<T> MC(this->get_mac_key());
|
||||
|
||||
if (machine.amplify)
|
||||
preampTriples.resize(nTriplesPerLoop);
|
||||
if (machine.generateMACs)
|
||||
{
|
||||
amplifiedTriples.resize(nTriplesPerLoop);
|
||||
uncheckedTriples.resize(nTriplesPerLoop);
|
||||
}
|
||||
|
||||
uncheckedTriples.resize(nTriplesPerLoop);
|
||||
|
||||
this->start_progress();
|
||||
|
||||
for (int k = 0; k < nloops; k++)
|
||||
@@ -581,10 +576,15 @@ void MascotTripleGenerator<U>::generateTriples()
|
||||
|
||||
if (machine.amplify)
|
||||
{
|
||||
octet seed[SEED_SIZE];
|
||||
Create_Random_Seed(seed, globalPlayer, SEED_SIZE);
|
||||
PRNG G;
|
||||
G.SetSeed(seed);
|
||||
if (machine.fiat_shamir and nparties == 2)
|
||||
ot_multipliers[0]->otCorrelator.common_seed(G);
|
||||
else
|
||||
{
|
||||
octet seed[SEED_SIZE];
|
||||
Create_Random_Seed(seed, globalPlayer, SEED_SIZE);
|
||||
G.SetSeed(seed);
|
||||
}
|
||||
for (int iTriple = 0; iTriple < nTriplesPerLoop; iTriple++)
|
||||
{
|
||||
PlainTriple<T,2> triple;
|
||||
@@ -598,12 +598,16 @@ void MascotTripleGenerator<U>::generateTriples()
|
||||
triple.output(outputFile);
|
||||
timers["Writing"].stop();
|
||||
}
|
||||
else
|
||||
for (int i = 0; i < 3; i++)
|
||||
uncheckedTriples[iTriple].byIndex(i, 0).set_share(triple.byIndex(i, 0));
|
||||
}
|
||||
|
||||
if (machine.generateMACs)
|
||||
{
|
||||
for (int iTriple = 0; iTriple < nTriplesPerLoop; iTriple++)
|
||||
amplifiedTriples[iTriple].to(valueBits, iTriple);
|
||||
amplifiedTriples[iTriple].to(valueBits, iTriple,
|
||||
machine.check ? 2 : 1);
|
||||
|
||||
for (int i = 0; i < nparties-1; i++)
|
||||
ot_multipliers[i]->inbox.push({});
|
||||
@@ -625,7 +629,7 @@ void MascotTripleGenerator<U>::generateTriples()
|
||||
|
||||
if (machine.check)
|
||||
{
|
||||
sacrifice(uncheckedTriples, MC, G);
|
||||
sacrifice(uncheckedTriples, this->MC ? *this->MC : MC, G);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -259,7 +259,7 @@ void naive_transpose64(vector<BitVector>& output, const vector<BitVector>& input
|
||||
}
|
||||
|
||||
|
||||
OTExtension::OTExtension(BaseOT& baseOT, TwoPartyPlayer* player,
|
||||
OTExtension::OTExtension(const BaseOT& baseOT, TwoPartyPlayer* player,
|
||||
bool passive) : player(player)
|
||||
{
|
||||
nbaseOTs = baseOT.nOT;
|
||||
|
||||
@@ -30,7 +30,7 @@ public:
|
||||
vector<BitVector> receiverOutput;
|
||||
map<string,long long> times;
|
||||
|
||||
OTExtension(BaseOT& baseOT, TwoPartyPlayer* player, bool passive);
|
||||
OTExtension(const BaseOT& baseOT, TwoPartyPlayer* player, bool passive);
|
||||
|
||||
OTExtension(int nbaseOTs, int baseLength,
|
||||
int nloops, int nsubloops,
|
||||
|
||||
@@ -308,7 +308,8 @@ void OTExtensionWithMatrix::hash_outputs(int nOTs)
|
||||
}
|
||||
|
||||
template <class V>
|
||||
void OTExtensionWithMatrix::hash_outputs(int nOTs, vector<V>& senderOutput, V& receiverOutput)
|
||||
void OTExtensionWithMatrix::hash_outputs(int nOTs, vector<V>& senderOutput,
|
||||
V& receiverOutput, bool correlated)
|
||||
{
|
||||
//cout << "Hashing... " << flush;
|
||||
octetStream os, h_os(HASH_SIZE);
|
||||
@@ -341,7 +342,11 @@ void OTExtensionWithMatrix::hash_outputs(int nOTs, vector<V>& senderOutput, V& r
|
||||
for (int j = 0; j < 8; j++)
|
||||
{
|
||||
tmp[0][j] = senderOutputMatrices[0].squares[i_outer_input].rows[i_inner_input + j];
|
||||
tmp[1][j] = tmp[0][j] ^ baseReceiverInput.get_int128(0);
|
||||
if (correlated)
|
||||
tmp[1][j] = tmp[0][j] ^ baseReceiverInput.get_int128(0);
|
||||
else
|
||||
tmp[1][j] =
|
||||
senderOutputMatrices[1].squares[i_outer_input].rows[i_inner_input + j];
|
||||
}
|
||||
for (int j = 0; j < 2; j++)
|
||||
mmo.hashBlocks<T, 8>(
|
||||
@@ -366,17 +371,39 @@ void OTExtensionWithMatrix::hash_outputs(int nOTs, vector<V>& senderOutput, V& r
|
||||
|
||||
template <class U>
|
||||
template <class T>
|
||||
void OTCorrelator<U>::reduce_squares(unsigned int nTriples, vector<T>& output)
|
||||
void OTCorrelator<U>::reduce_squares(unsigned int nTriples, vector<T>& output, int start)
|
||||
{
|
||||
if (receiverOutputMatrix.squares.size() < nTriples)
|
||||
if (receiverOutputMatrix.squares.size() < nTriples + start)
|
||||
throw invalid_length();
|
||||
output.resize(nTriples);
|
||||
for (unsigned int j = 0; j < nTriples; j++)
|
||||
{
|
||||
receiverOutputMatrix.squares[j].sub(senderOutputMatrices[0].squares[j]).to(output[j]);
|
||||
receiverOutputMatrix.squares[j + start].sub(
|
||||
senderOutputMatrices[0].squares[j + start]).to(output[j]);
|
||||
}
|
||||
}
|
||||
|
||||
template <class U>
|
||||
void OTCorrelator<U>::common_seed(PRNG& G)
|
||||
{
|
||||
Slice<U> t1Slice(t1, 0, t1.squares.size());
|
||||
Slice<U> uSlice(u, 0, u.squares.size());
|
||||
|
||||
octetStream os;
|
||||
if (player->my_num())
|
||||
{
|
||||
t1Slice.pack(os);
|
||||
uSlice.pack(os);
|
||||
}
|
||||
else
|
||||
{
|
||||
uSlice.pack(os);
|
||||
t1Slice.pack(os);
|
||||
}
|
||||
auto hash = os.hash();
|
||||
G = PRNG(hash);
|
||||
}
|
||||
|
||||
octet* OTExtensionWithMatrix::get_receiver_output(int i)
|
||||
{
|
||||
return (octet*)&receiverOutputMatrix.squares[i/128].rows[i%128];
|
||||
@@ -515,34 +542,35 @@ template class OTCorrelator<BitMatrix>;
|
||||
#define Z(BM,GF) \
|
||||
template class OTCorrelator<BM>; \
|
||||
template void OTCorrelator<BM>::reduce_squares<GF>(unsigned int nTriples, \
|
||||
vector<GF>& output);
|
||||
vector<GF>& output, int);
|
||||
|
||||
#define ZZZZ(GF) \
|
||||
template void OTExtensionWithMatrix::print_post_correlate<GF>( \
|
||||
BitVector& newReceiverInput, int j, int offset, int sender); \
|
||||
|
||||
#define ZZZ(GF, M) Z(M, GF) \
|
||||
template void OTExtensionWithMatrix::hash_outputs(int, vector<M >&, M&);
|
||||
template void OTExtensionWithMatrix::hash_outputs(int, vector<M >&, M&, bool);
|
||||
|
||||
ZZZZ(gf2n_long)
|
||||
ZZZ(gf2n_short, Matrix<gf2n_short_square>)
|
||||
ZZZ(gf2n_long, Matrix<Square<gf2n_long>>)
|
||||
ZZZ(gfp1, Matrix<Square<gfp1>>)
|
||||
ZZZ(gfp3, Matrix<Square<gfp3>>)
|
||||
ZZZ(BitVec, Matrix<BitDiagonal>)
|
||||
|
||||
#undef XX
|
||||
#define XX(T,U,N,L) \
|
||||
template class OTCorrelator<Matrix<Rectangle<Z2<N>, Z2<L> > > >; \
|
||||
template void OTCorrelator<Matrix<Rectangle<Z2<N>, Z2<L> > > >::reduce_squares(unsigned int nTriples, \
|
||||
vector<U>& output); \
|
||||
vector<U>& output, int); \
|
||||
template void OTExtensionWithMatrix::hash_outputs(int, \
|
||||
std::vector<Matrix<Rectangle<Z2<N>, Z2<L> > >, std::allocator<Matrix<Rectangle<Z2<N>, Z2<L> > > > >&, \
|
||||
Matrix<Rectangle<Z2<N>, Z2<L> > >&);
|
||||
Matrix<Rectangle<Z2<N>, Z2<L> > >&, bool);
|
||||
|
||||
#undef X
|
||||
#define X(N,L) \
|
||||
template void OTCorrelator<Matrix<Rectangle<Z2<N>, Z2<L> > > >::reduce_squares(unsigned int nTriples, \
|
||||
vector<Z2kRectangle<N, L> >& output); \
|
||||
vector<Z2kRectangle<N, L> >& output, int); \
|
||||
XX(Z2<L>,Z2<N>,N,L)
|
||||
|
||||
//X(96, 160)
|
||||
|
||||
@@ -45,7 +45,9 @@ public:
|
||||
U& baseReceiverOutput);
|
||||
void correlate(int start, int slice, BitVector& newReceiverInput, bool useConstantBase, int repeat = 1);
|
||||
template <class T>
|
||||
void reduce_squares(unsigned int nTriples, vector<T>& output);
|
||||
void reduce_squares(unsigned int nTriples, vector<T>& output,
|
||||
int start = 0);
|
||||
void common_seed(PRNG& G);
|
||||
};
|
||||
|
||||
class OTExtensionWithMatrix : public OTCorrelator<BitMatrix>
|
||||
@@ -80,7 +82,8 @@ public:
|
||||
void transpose(int start, int slice);
|
||||
void expand_transposed();
|
||||
template <class V>
|
||||
void hash_outputs(int nOTs, vector<V>& senderOutput, V& receiverOutput);
|
||||
void hash_outputs(int nOTs, vector<V>& senderOutput, V& receiverOutput,
|
||||
bool correlated = true);
|
||||
|
||||
void print(BitVector& newReceiverInput, int i = 0);
|
||||
template <class T>
|
||||
|
||||
@@ -7,18 +7,11 @@
|
||||
|
||||
#include "OT/OTMultiplier.h"
|
||||
#include "OT/NPartyTripleGenerator.h"
|
||||
#include "OT/Rectangle.h"
|
||||
#include "Math/Z2k.h"
|
||||
#include "Math/BitVec.h"
|
||||
#include "Protocols/SemiShare.h"
|
||||
#include "Protocols/Semi2kShare.h"
|
||||
#include "Protocols/Spdz2kShare.h"
|
||||
#include "OT/BaseOT.h"
|
||||
|
||||
#include "OT/OTVole.hpp"
|
||||
#include "OT/Row.hpp"
|
||||
#include "OT/Rectangle.hpp"
|
||||
#include "Math/Z2k.hpp"
|
||||
#include "Math/Square.hpp"
|
||||
|
||||
#include <math.h>
|
||||
|
||||
@@ -31,7 +24,8 @@ OTMultiplier<T>::OTMultiplier(OTTripleGenerator<T>& generator,
|
||||
rot_ext(128, 128, 0, 1,
|
||||
generator.players[thread_num], generator.baseReceiverInput,
|
||||
generator.baseSenderInputs[thread_num],
|
||||
generator.baseReceiverOutputs[thread_num], BOTH, !generator.machine.check),
|
||||
generator.baseReceiverOutputs[thread_num], BOTH,
|
||||
!generator.machine.correlation_check),
|
||||
otCorrelator(0, 0, 0, 0, generator.players[thread_num], {}, {}, {}, BOTH, true)
|
||||
{
|
||||
this->thread = 0;
|
||||
@@ -89,7 +83,7 @@ OTMultiplier<T>::~OTMultiplier()
|
||||
template<class T>
|
||||
void OTMultiplier<T>::multiply()
|
||||
{
|
||||
keyBits.set(generator.machine.template get_mac_key<typename T::mac_key_type>());
|
||||
keyBits.set(generator.get_mac_key());
|
||||
rot_ext.extend(keyBits.size(), keyBits);
|
||||
this->outbox.push({});
|
||||
senderOutput.resize(keyBits.size());
|
||||
@@ -140,11 +134,6 @@ void OTMultiplier<W>::multiplyForTriples()
|
||||
{
|
||||
typedef typename W::Rectangle X;
|
||||
|
||||
// dummy input for OT correlator
|
||||
vector<BitVector> _;
|
||||
vector< vector<BitVector> > __;
|
||||
BitVector ___;
|
||||
|
||||
otCorrelator.resize(X::N_COLUMNS * generator.nPreampTriplesPerLoop);
|
||||
|
||||
rot_ext.resize(X::N_ROWS * generator.nPreampTriplesPerLoop + 2 * 128);
|
||||
@@ -161,8 +150,26 @@ void OTMultiplier<W>::multiplyForTriples()
|
||||
this->inbox.pop(job);
|
||||
BitVector aBits = generator.valueBits[0];
|
||||
//timers["Extension"].start();
|
||||
rot_ext.extend_correlated(aBits);
|
||||
rot_ext.hash_outputs(aBits.size(), baseSenderOutputs, baseReceiverOutput);
|
||||
if (generator.machine.use_extension)
|
||||
{
|
||||
rot_ext.extend_correlated(aBits);
|
||||
}
|
||||
else
|
||||
{
|
||||
BaseOT bot(aBits.size(), -1, generator.players[thread_num]);
|
||||
bot.set_receiver_inputs(aBits);
|
||||
bot.exec_base(false);
|
||||
for (size_t i = 0; i < aBits.size(); i++)
|
||||
{
|
||||
rot_ext.receiverOutputMatrix[i] =
|
||||
bot.receiver_outputs[i].get_int128(0).a;
|
||||
for (int j = 0; j < 2; j++)
|
||||
rot_ext.senderOutputMatrices[j][i] =
|
||||
bot.sender_inputs[i][j].get_int128(0).a;
|
||||
}
|
||||
}
|
||||
rot_ext.hash_outputs(aBits.size(), baseSenderOutputs,
|
||||
baseReceiverOutput, generator.machine.use_extension);
|
||||
//timers["Extension"].stop();
|
||||
|
||||
//timers["Correlation"].start();
|
||||
@@ -215,8 +222,6 @@ void MascotMultiplier<U>::after_correlation()
|
||||
{
|
||||
typedef typename U::open_type T;
|
||||
|
||||
this->auth_ot_ext.resize(
|
||||
this->generator.nPreampTriplesPerLoop * T::Square::N_COLUMNS);
|
||||
this->auth_ot_ext.set_role(BOTH);
|
||||
|
||||
this->otCorrelator.reduce_squares(this->generator.nPreampTriplesPerLoop,
|
||||
@@ -229,15 +234,45 @@ void MascotMultiplier<U>::after_correlation()
|
||||
this->macs.resize(3);
|
||||
MultJob job;
|
||||
this->inbox.pop(job);
|
||||
auto& generator = this->generator;
|
||||
array<int, 3> n_vals;
|
||||
for (int j = 0; j < 3; j++)
|
||||
{
|
||||
int nValues = this->generator.nTriplesPerLoop;
|
||||
n_vals[j] = generator.nTriplesPerLoop;
|
||||
if (this->generator.machine.check && (j % 2 == 0))
|
||||
nValues *= 2;
|
||||
this->auth_ot_ext.expand(0, nValues);
|
||||
this->auth_ot_ext.correlate(0, nValues,
|
||||
this->generator.valueBits[j], true);
|
||||
this->auth_ot_ext.reduce_squares(nValues, this->macs[j]);
|
||||
n_vals[j] *= 2;
|
||||
}
|
||||
if (generator.machine.fewer_rounds)
|
||||
{
|
||||
BitVector bits;
|
||||
int total = 0;
|
||||
for (int j = 0; j < 3; j++)
|
||||
{
|
||||
bits.append(generator.valueBits[j],
|
||||
n_vals[j] * T::Square::N_COLUMNS);
|
||||
total += n_vals[j];
|
||||
}
|
||||
this->auth_ot_ext.resize(bits.size());
|
||||
this->auth_ot_ext.expand(0, total);
|
||||
this->auth_ot_ext.correlate(0, total, bits, true);
|
||||
total = 0;
|
||||
for (int j = 0; j < 3; j++)
|
||||
{
|
||||
this->auth_ot_ext.reduce_squares(n_vals[j], this->macs[j], total);
|
||||
total += n_vals[j];
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
this->auth_ot_ext.resize(n_vals[0] * T::Square::N_COLUMNS);
|
||||
for (int j = 0; j < 3; j++)
|
||||
{
|
||||
int nValues = n_vals[j];
|
||||
this->auth_ot_ext.expand(0, nValues);
|
||||
this->auth_ot_ext.correlate(0, nValues,
|
||||
this->generator.valueBits[j], true);
|
||||
this->auth_ot_ext.reduce_squares(nValues, this->macs[j]);
|
||||
}
|
||||
}
|
||||
this->outbox.push(job);
|
||||
}
|
||||
|
||||
@@ -41,3 +41,22 @@ void OTTripleSetup::close_connections()
|
||||
delete players[i];
|
||||
}
|
||||
}
|
||||
|
||||
OTTripleSetup OTTripleSetup::get_fresh()
|
||||
{
|
||||
OTTripleSetup res = *this;
|
||||
for (int i = 0; i < nparties - 1; i++)
|
||||
{
|
||||
BaseOT bot(nbase, 128, 0);
|
||||
bot.sender_inputs = baseSenderInputs[i];
|
||||
bot.receiver_outputs = baseReceiverOutputs[i];
|
||||
bot.set_seeds();
|
||||
bot.extend_length();
|
||||
baseSenderInputs[i] = bot.sender_inputs;
|
||||
baseReceiverOutputs[i] = bot.receiver_outputs;
|
||||
bot.extend_length();
|
||||
res.baseSenderInputs[i] = bot.sender_inputs;
|
||||
res.baseReceiverOutputs[i] = bot.receiver_outputs;
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
@@ -11,7 +11,7 @@
|
||||
*/
|
||||
class OTTripleSetup
|
||||
{
|
||||
vector<int> base_receiver_inputs;
|
||||
BitVector base_receiver_inputs;
|
||||
vector<BaseOT*> baseOTs;
|
||||
|
||||
PRNG G;
|
||||
@@ -25,10 +25,10 @@ public:
|
||||
vector< vector< vector<BitVector> > > baseSenderInputs;
|
||||
vector< vector<BitVector> > baseReceiverOutputs;
|
||||
|
||||
int get_nparties() { return nparties; }
|
||||
int get_nbase() { return nbase; }
|
||||
int get_my_num() { return my_num; }
|
||||
int get_base_receiver_input(int i) { return base_receiver_inputs[i]; }
|
||||
int get_nparties() const { return nparties; }
|
||||
int get_nbase() const { return nbase; }
|
||||
int get_my_num() const { return my_num; }
|
||||
int get_base_receiver_input(int i) const { return base_receiver_inputs[i]; }
|
||||
|
||||
OTTripleSetup(Player& N, bool real_OTs)
|
||||
: nparties(N.num_players()), my_num(N.my_num()), nbase(128)
|
||||
@@ -78,6 +78,8 @@ public:
|
||||
|
||||
//template <class T>
|
||||
//T get_mac_key();
|
||||
|
||||
OTTripleSetup get_fresh();
|
||||
};
|
||||
|
||||
|
||||
|
||||
@@ -16,13 +16,16 @@ public:
|
||||
T b;
|
||||
T c[N];
|
||||
|
||||
int repeat(int l)
|
||||
int repeat(int l, bool check)
|
||||
{
|
||||
switch (l)
|
||||
{
|
||||
case 0:
|
||||
case 2:
|
||||
return N;
|
||||
if (check)
|
||||
return N;
|
||||
else
|
||||
return 1;
|
||||
case 1:
|
||||
return 1;
|
||||
default:
|
||||
@@ -75,12 +78,12 @@ class PlainTriple : public Triple<T,N>
|
||||
{
|
||||
public:
|
||||
// this assumes that valueBits[1] is still set to the bits of b
|
||||
void to(vector<BitVector>& valueBits, int i)
|
||||
void to(vector<BitVector>& valueBits, int i, int repeat = N)
|
||||
{
|
||||
for (int j = 0; j < N; j++)
|
||||
{
|
||||
valueBits[0].set_portion(i * N + j, this->a[j]);
|
||||
valueBits[2].set_portion(i * N + j, this->c[j]);
|
||||
valueBits[0].set_portion(i * repeat + j, this->a[j]);
|
||||
valueBits[2].set_portion(i * repeat + j, this->c[j]);
|
||||
}
|
||||
}
|
||||
};
|
||||
@@ -123,12 +126,12 @@ public:
|
||||
{
|
||||
for (int l = 0; l < 3; l++)
|
||||
{
|
||||
int repeat = this->repeat(l);
|
||||
int repeat = this->repeat(l, generator.machine.check);
|
||||
for (int j = 0; j < repeat; j++)
|
||||
{
|
||||
T value = triple.byIndex(l,j);
|
||||
T mac;
|
||||
mac.mul(value, generator.machine.template get_mac_key<U>());
|
||||
mac.mul(value, generator.get_mac_key());
|
||||
for (int i = 0; i < generator.nparties-1; i++)
|
||||
mac += generator.ot_multipliers[i]->macs.at(l).at(iTriple * repeat + j);
|
||||
Share<T>& share = this->byIndex(l,j);
|
||||
|
||||
@@ -16,28 +16,21 @@ class GeneratorThread;
|
||||
|
||||
class MascotParams : virtual public OfflineParams
|
||||
{
|
||||
protected:
|
||||
gf2n_short mac_key2s;
|
||||
gf2n_long mac_key2l;
|
||||
gfp1 mac_keyp;
|
||||
Z2<128> mac_keyz;
|
||||
|
||||
public:
|
||||
string prep_data_dir;
|
||||
bool generateMACs;
|
||||
bool amplify;
|
||||
bool check;
|
||||
bool correlation_check;
|
||||
bool generateBits;
|
||||
bool use_extension;
|
||||
bool fewer_rounds;
|
||||
bool fiat_shamir;
|
||||
struct timeval start, stop;
|
||||
|
||||
MascotParams();
|
||||
|
||||
void set_passive();
|
||||
|
||||
template <class T>
|
||||
T get_mac_key();
|
||||
template <class T>
|
||||
void set_mac_key(T key);
|
||||
};
|
||||
|
||||
class TripleMachine : public OfflineMachineBase, public MascotParams
|
||||
@@ -45,6 +38,10 @@ class TripleMachine : public OfflineMachineBase, public MascotParams
|
||||
Names N[2];
|
||||
int nConnections;
|
||||
|
||||
gf2n mac_key2;
|
||||
gfp1 mac_keyp;
|
||||
Z2<128> mac_keyz;
|
||||
|
||||
public:
|
||||
int nloops;
|
||||
bool primeField;
|
||||
@@ -54,7 +51,8 @@ public:
|
||||
TripleMachine(int argc, const char** argv);
|
||||
|
||||
template<class T>
|
||||
GeneratorThread* new_generator(OTTripleSetup& setup, int i);
|
||||
GeneratorThread* new_generator(OTTripleSetup& setup, int i,
|
||||
typename T::mac_key_type mac_key);
|
||||
|
||||
void run();
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ public:
|
||||
string progname;
|
||||
int nthreads;
|
||||
|
||||
vector<vector<OTTripleSetup>> ot_setups;
|
||||
vector<OTTripleSetup> ot_setups;
|
||||
|
||||
static BaseMachine& s();
|
||||
|
||||
|
||||
@@ -94,6 +94,7 @@ public:
|
||||
virtual void purge() {}
|
||||
|
||||
virtual size_t data_sent() { return 0; }
|
||||
virtual NamedCommStats comm_stats() { return {}; }
|
||||
|
||||
virtual void get_three_no_count(Dtype dtype, T& a, T& b, T& c) = 0;
|
||||
virtual void get_two_no_count(Dtype dtype, T& a, T& b) = 0;
|
||||
@@ -112,6 +113,7 @@ public:
|
||||
virtual array<T, 3> get_triple(int n_bits);
|
||||
|
||||
virtual void buffer_triples() {}
|
||||
virtual void buffer_inverses() {}
|
||||
};
|
||||
|
||||
template<class T>
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
#include "FixInput.h"
|
||||
#include "FloatInput.h"
|
||||
|
||||
#include "IntInput.hpp"
|
||||
|
||||
template<class T>
|
||||
InputBase<T>::InputBase(ArithmeticProcessor* proc) :
|
||||
P(0), values_input(0)
|
||||
@@ -295,7 +297,7 @@ void InputBase<T>::input_mixed(SubProcessor<T>& Proc, const vector<int>& args,
|
||||
cout << "Please input " << U::NAME << "s:" << endl; \
|
||||
prepare<U>(Proc, player, &args[i + U::N_DEST + 1], size); \
|
||||
break;
|
||||
X(IntInput) X(FixInput) X(FloatInput)
|
||||
X(IntInput<typename T::clear>) X(FixInput) X(FloatInput)
|
||||
#undef X
|
||||
default:
|
||||
throw runtime_error("unknown input type: " + to_string(type));
|
||||
@@ -317,7 +319,7 @@ void InputBase<T>::input_mixed(SubProcessor<T>& Proc, const vector<int>& args,
|
||||
n_arg_tuple = U::N_DEST + U::N_PARAM + 2; \
|
||||
finalize<U>(Proc, args[i + n_arg_tuple - 1], &args[i + 1], size); \
|
||||
break;
|
||||
X(IntInput) X(FixInput) X(FloatInput)
|
||||
X(IntInput<typename T::clear>) X(FixInput) X(FloatInput)
|
||||
#undef X
|
||||
default:
|
||||
throw runtime_error("unknown input type: " + to_string(type));
|
||||
|
||||
@@ -61,6 +61,9 @@ enum
|
||||
USE_PREP = 0x1C,
|
||||
STARTGRIND = 0x1D,
|
||||
STOPGRIND = 0x1E,
|
||||
NPLAYERS = 0xE2,
|
||||
THRESHOLD = 0xE3,
|
||||
PLAYERID = 0xE4,
|
||||
// Addition
|
||||
ADDC = 0x20,
|
||||
ADDS = 0x21,
|
||||
|
||||
@@ -177,6 +177,9 @@ void BaseInstruction::parse_operands(istream& s, int pos)
|
||||
case PRINTCHRINT:
|
||||
case PRINTSTRINT:
|
||||
case PRINTINT:
|
||||
case NPLAYERS:
|
||||
case THRESHOLD:
|
||||
case PLAYERID:
|
||||
r[0]=get_int(s);
|
||||
break;
|
||||
// instructions with 3 registers + 1 integer operand
|
||||
@@ -442,6 +445,9 @@ int BaseInstruction::get_reg_type() const
|
||||
case CONVMODP:
|
||||
case GCONVGF2N:
|
||||
case RAND:
|
||||
case NPLAYERS:
|
||||
case THRESHOLD:
|
||||
case PLAYERID:
|
||||
return INT;
|
||||
case PREP:
|
||||
case USE_PREP:
|
||||
@@ -1046,10 +1052,10 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
|
||||
Proc.temp.ans2.output(Proc.private_output, false);
|
||||
break;
|
||||
case INPUT:
|
||||
sint::Input::template input<IntInput>(Proc.Procp, start, size);
|
||||
sint::Input::template input<IntInput<typename sint::clear>>(Proc.Procp, start, size);
|
||||
return;
|
||||
case GINPUT:
|
||||
sgf2n::Input::template input<IntInput>(Proc.Proc2, start, size);
|
||||
sgf2n::Input::template input<IntInput<typename sgf2n::clear>>(Proc.Proc2, start, size);
|
||||
return;
|
||||
case INPUTFIX:
|
||||
sint::Input::template input<FixInput>(Proc.Procp, start, size);
|
||||
@@ -1404,6 +1410,15 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
|
||||
case STOPGRIND:
|
||||
CALLGRIND_STOP_INSTRUMENTATION;
|
||||
break;
|
||||
case NPLAYERS:
|
||||
Proc.write_Ci(r[0], Proc.P.num_players());
|
||||
break;
|
||||
case THRESHOLD:
|
||||
Proc.write_Ci(r[0], sint::threshold(Proc.P.num_players()));
|
||||
break;
|
||||
case PLAYERID:
|
||||
Proc.write_Ci(r[0], Proc.P.my_num());
|
||||
break;
|
||||
// ***
|
||||
// TODO: read/write shared GF(2^n) data instructions
|
||||
// ***
|
||||
|
||||
@@ -1,14 +0,0 @@
|
||||
/*
|
||||
* IntInput.cpp
|
||||
*
|
||||
*/
|
||||
|
||||
#include "IntInput.h"
|
||||
|
||||
const char* IntInput::NAME = "integer";
|
||||
|
||||
void IntInput::read(std::istream& in, const int* params)
|
||||
{
|
||||
(void) params;
|
||||
in >> items[0];
|
||||
}
|
||||
@@ -8,6 +8,7 @@
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template<class T>
|
||||
class IntInput
|
||||
{
|
||||
public:
|
||||
@@ -17,7 +18,7 @@ public:
|
||||
|
||||
const static int TYPE = 0;
|
||||
|
||||
long items[N_DEST];
|
||||
T items[N_DEST];
|
||||
|
||||
void read(std::istream& in, const int* params);
|
||||
};
|
||||
|
||||
15
Processor/IntInput.hpp
Normal file
15
Processor/IntInput.hpp
Normal file
@@ -0,0 +1,15 @@
|
||||
/*
|
||||
* IntInput.cpp
|
||||
*
|
||||
*/
|
||||
|
||||
#include "IntInput.h"
|
||||
|
||||
template<class T>
|
||||
const char* IntInput<T>::NAME = "integer";
|
||||
|
||||
template<class T>
|
||||
void IntInput<T>::read(std::istream& in, const int*)
|
||||
{
|
||||
in >> items[0];
|
||||
}
|
||||
@@ -127,10 +127,8 @@ Machine<sint, sgf2n>::Machine(int my_number, Names& playerNames,
|
||||
P = new CryptoPlayer(playerNames, 0xF000);
|
||||
else
|
||||
P = new PlainPlayer(playerNames, 0xF000);
|
||||
ot_setups.resize(nthreads);
|
||||
for (int i = 0; i < nthreads; i++)
|
||||
for (int j = 0; j < 3; j++)
|
||||
ot_setups.at(i).push_back({ *P, true });
|
||||
ot_setups.push_back({ *P, true });
|
||||
delete P;
|
||||
}
|
||||
|
||||
|
||||
@@ -52,8 +52,6 @@ public:
|
||||
Preprocessing<T>& DataF, Player& P);
|
||||
|
||||
// Access to PO (via calls to POpen start/stop)
|
||||
void POpen_Start(const vector<int>& reg,const Player& P,int size);
|
||||
void POpen_Stop(const vector<int>& reg,const Player& P,int size);
|
||||
void POpen(const vector<int>& reg,const Player& P,int size);
|
||||
|
||||
void muls(const vector<int>& reg, int size);
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
|
||||
#include "Protocols/ReplicatedInput.hpp"
|
||||
#include "Protocols/ReplicatedPrivateOutput.hpp"
|
||||
#include "Processor/ProcessorBase.hpp"
|
||||
|
||||
#include <sodium.h>
|
||||
#include <string>
|
||||
@@ -406,15 +407,16 @@ void Processor<sint, sgf2n>::write_shares_to_file(const vector<int>& data_regist
|
||||
}
|
||||
|
||||
template <class T>
|
||||
void SubProcessor<T>::POpen_Start(const vector<int>& reg,const Player& P,int size)
|
||||
void SubProcessor<T>::POpen(const vector<int>& reg,const Player& P,int size)
|
||||
{
|
||||
int sz=reg.size();
|
||||
assert(reg.size() % 2 == 0);
|
||||
int sz=reg.size() / 2;
|
||||
Sh_PO.clear();
|
||||
Sh_PO.reserve(sz*size);
|
||||
if (size>1)
|
||||
{
|
||||
for (typename vector<int>::const_iterator reg_it=reg.begin();
|
||||
reg_it!=reg.end(); reg_it++)
|
||||
for (typename vector<int>::const_iterator reg_it=reg.begin() + 1;
|
||||
reg_it < reg.end(); reg_it += 2)
|
||||
{
|
||||
auto begin=S.begin()+*reg_it;
|
||||
Sh_PO.insert(Sh_PO.end(),begin,begin+size);
|
||||
@@ -423,24 +425,15 @@ void SubProcessor<T>::POpen_Start(const vector<int>& reg,const Player& P,int siz
|
||||
else
|
||||
{
|
||||
for (int i=0; i<sz; i++)
|
||||
{ Sh_PO.push_back(S[reg[i]]); }
|
||||
{ Sh_PO.push_back(S[reg[2 * i + 1]]); }
|
||||
}
|
||||
PO.resize(sz*size);
|
||||
MC.POpen_Begin(PO,Sh_PO,P);
|
||||
}
|
||||
|
||||
|
||||
template <class T>
|
||||
void SubProcessor<T>::POpen_Stop(const vector<int>& reg,const Player& P,int size)
|
||||
{
|
||||
int sz=reg.size();
|
||||
PO.resize(sz*size);
|
||||
MC.POpen_End(PO,Sh_PO,P);
|
||||
MC.POpen(PO,Sh_PO,P);
|
||||
if (size>1)
|
||||
{
|
||||
auto PO_it=PO.begin();
|
||||
for (typename vector<int>::const_iterator reg_it=reg.begin();
|
||||
reg_it!=reg.end(); reg_it++)
|
||||
reg_it!=reg.end(); reg_it += 2)
|
||||
{
|
||||
for (auto C_it=C.begin()+*reg_it;
|
||||
C_it!=C.begin()+*reg_it+size; C_it++)
|
||||
@@ -452,36 +445,16 @@ void SubProcessor<T>::POpen_Stop(const vector<int>& reg,const Player& P,int size
|
||||
}
|
||||
else
|
||||
{
|
||||
for (unsigned int i=0; i<reg.size(); i++)
|
||||
{ C[reg[i]] = PO[i]; }
|
||||
for (unsigned int i = 0; i < reg.size() / 2; i++)
|
||||
{
|
||||
C[reg[2 * i]] = PO[i];
|
||||
}
|
||||
}
|
||||
|
||||
Proc.sent += reg.size() * size;
|
||||
Proc.rounds++;
|
||||
}
|
||||
|
||||
inline void unzip_open(vector<int>& dest, vector<int>& source, const vector<int>& reg)
|
||||
{
|
||||
int n = reg.size() / 2;
|
||||
source.resize(n);
|
||||
dest.resize(n);
|
||||
for (int i = 0; i < n; i++)
|
||||
{
|
||||
source[i] = reg[2 * i + 1];
|
||||
dest[i] = reg[2 * i];
|
||||
}
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void SubProcessor<T>::POpen(const vector<int>& reg, const Player& P,
|
||||
int size)
|
||||
{
|
||||
vector<int> source, dest;
|
||||
unzip_open(dest, source, reg);
|
||||
POpen_Start(source, P, size);
|
||||
POpen_Stop(dest, P, size);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void SubProcessor<T>::muls(const vector<int>& reg, int size)
|
||||
{
|
||||
|
||||
@@ -3,6 +3,9 @@
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef PROCESSOR_PROCESSORBASE_HPP_
|
||||
#define PROCESSOR_PROCESSORBASE_HPP_
|
||||
|
||||
#include "ProcessorBase.h"
|
||||
#include "IntInput.h"
|
||||
#include "FixInput.h"
|
||||
@@ -11,6 +14,7 @@
|
||||
|
||||
#include <iostream>
|
||||
|
||||
inline
|
||||
void ProcessorBase::open_input_file(const string& name)
|
||||
{
|
||||
#ifdef DEBUG_FILES
|
||||
@@ -20,6 +24,7 @@ void ProcessorBase::open_input_file(const string& name)
|
||||
input_filename = name;
|
||||
}
|
||||
|
||||
inline
|
||||
void ProcessorBase::open_input_file(int my_num, int thread_num)
|
||||
{
|
||||
string input_file = "Player-Data/Input-P" + to_string(my_num) + "-" + to_string(thread_num);
|
||||
@@ -54,6 +59,4 @@ T ProcessorBase::get_input(istream& input_file, const string& input_filename, co
|
||||
return res;
|
||||
}
|
||||
|
||||
template IntInput ProcessorBase::get_input(bool, const int*);
|
||||
template FixInput ProcessorBase::get_input(bool, const int*);
|
||||
template FloatInput ProcessorBase::get_input(bool, const int*);
|
||||
#endif
|
||||
@@ -129,7 +129,7 @@ def expandAESKey(cipherKey, Nr = 10, Nb = 4, Nk = 4):
|
||||
temp[2] = box.apply_sbox(temp[2])
|
||||
temp[3] = box.apply_sbox(temp[3])
|
||||
|
||||
temp[0] = temp[0] + ApplyEmbedding(rcon[int(i/Nk)])
|
||||
temp[0] = temp[0] + ApplyEmbedding(rcon[int(i//Nk)])
|
||||
|
||||
for j in range(4):
|
||||
round_key[4 * i + j] = round_key[4 * (i - Nk) + j] + temp[j]
|
||||
@@ -233,7 +233,7 @@ def inverseMod(val):
|
||||
|
||||
for idx in range(40):
|
||||
if idx % 5 == 0:
|
||||
bd_val[idx] = raw_bit_dec[idx / 5]
|
||||
bd_val[idx] = raw_bit_dec[idx // 5]
|
||||
|
||||
bd_squared = bd_val
|
||||
squared_index = 2
|
||||
|
||||
@@ -118,10 +118,10 @@ def main():
|
||||
return True
|
||||
|
||||
if n_rounds > 0:
|
||||
print 'run %d rounds' % n_rounds
|
||||
print('run %d rounds' % n_rounds)
|
||||
for_range(n_rounds)(game_loop)
|
||||
else:
|
||||
print 'run forever'
|
||||
print('run forever')
|
||||
do_while(game_loop)
|
||||
|
||||
main()
|
||||
|
||||
@@ -13,7 +13,7 @@ if len(program.args) > 2:
|
||||
m = int(program.args[2])
|
||||
|
||||
pack = min(n, 50)
|
||||
n = (n + pack - 1) / pack
|
||||
n = (n + pack - 1) // pack
|
||||
|
||||
a = sbit(1)
|
||||
b = sbit(1, n=pack)
|
||||
|
||||
@@ -50,9 +50,9 @@ test_decryption = True
|
||||
instructions_base.set_global_vector_size(n_parallel)
|
||||
|
||||
if use_mimc_prf:
|
||||
execfile('./Programs/Source/prf_mimc.mpc')
|
||||
exec(compile(__builtins__['open']('./Programs/Source/prf_mimc.mpc').read(), './Programs/Source/prf_mimc.mpc', 'exec'))
|
||||
elif use_leg_prf:
|
||||
execfile('./Programs/Source/prf_leg.mpc')
|
||||
exec(compile(__builtins__['open']('./Programs/Source/prf_leg.mpc').read(), './Programs/Source/prf_leg.mpc', 'exec'))
|
||||
|
||||
class HMAC(object):
|
||||
def __init__(self, _enc):
|
||||
@@ -97,7 +97,7 @@ class NonceEncryptMAC(object):
|
||||
def get_long_random(self, nbits):
|
||||
""" Returns random cint() % 2^{nbits} """
|
||||
result = cint(0)
|
||||
for i in range(nbits / 30):
|
||||
for i in range(nbits // 30):
|
||||
result += cint(regint.get_random(30))
|
||||
result <<= 30
|
||||
|
||||
@@ -178,7 +178,7 @@ def time_private_mac(n_total, n_parallel, nmessages):
|
||||
|
||||
# Benchmark n_total HtMAC's while executing in parallel n_parallel
|
||||
start_timer(1)
|
||||
@for_range(n_total / n_parallel)
|
||||
@for_range(n_total // n_parallel)
|
||||
def block(index):
|
||||
# Re-use off-line data after n_parallel runs for benchmarking purposes.
|
||||
# If real system-use need to initialize num_calls with a larger constant.
|
||||
|
||||
@@ -8,7 +8,7 @@ ml.set_n_threads(8)
|
||||
debug = False
|
||||
|
||||
if 'halfprec' in program.args:
|
||||
print '8-bit precision after point'
|
||||
print('8-bit precision after point')
|
||||
sfix.set_precision(8, 31)
|
||||
cfix.set_precision(8, 31)
|
||||
else:
|
||||
@@ -26,7 +26,7 @@ n_features = 12634
|
||||
|
||||
if len(program.args) > 2:
|
||||
if 'bc' in program.args:
|
||||
print 'Compiling for BC-TCGA'
|
||||
print('Compiling for BC-TCGA')
|
||||
n_examples = 472
|
||||
n_normal = 49
|
||||
n_features = 17814
|
||||
@@ -41,7 +41,7 @@ try:
|
||||
except:
|
||||
pass
|
||||
|
||||
print 'Using %d threads' % ml.Layer.n_threads
|
||||
print('Using %d threads' % ml.Layer.n_threads)
|
||||
|
||||
n_fold = 5
|
||||
test_share = 1. / n_fold
|
||||
@@ -63,8 +63,8 @@ else:
|
||||
n_test = sum(n_tests)
|
||||
|
||||
indices = [regint.Array(n) for n in n_ex]
|
||||
indices[0].assign(range(n_pos, n_pos + n_normal))
|
||||
indices[1].assign(range(n_pos))
|
||||
indices[0].assign(list(range(n_pos, n_pos + n_normal)))
|
||||
indices[1].assign(list(range(n_pos)))
|
||||
|
||||
test = regint.Array(n_test)
|
||||
|
||||
@@ -97,7 +97,7 @@ for arg in program.args:
|
||||
m = re.match('tol=(.*)', arg)
|
||||
if m:
|
||||
sgd.tol = float(m.group(1))
|
||||
print 'Stop with tolerance', sgd.tol
|
||||
print('Stop with tolerance', sgd.tol)
|
||||
|
||||
sum_acc = cfix.MemValue(0)
|
||||
|
||||
|
||||
@@ -6,8 +6,8 @@ sbitfix.set_precision(16, 32)
|
||||
def test(a, b, value_type=None):
|
||||
try:
|
||||
b = int(round((b * (1 << a.f))))
|
||||
if b < 0:
|
||||
b += 2 ** sbitfix.k
|
||||
if b < 0:
|
||||
b += 2 ** sbitfix.k
|
||||
a = a.v.reveal()
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
@@ -55,12 +55,12 @@ def f(_):
|
||||
|
||||
def thread():
|
||||
i = get_arg()
|
||||
n_per_thread = n_inputs / n_threads
|
||||
n_per_thread = n_inputs // n_threads
|
||||
if n_per_thread % 2 != 0:
|
||||
raise Exception('Number of inputs must be divisible by 2')
|
||||
start = i * n_per_thread
|
||||
tuples = [bid_sort(bids[start+2*j], bids[start+2*j+1]) \
|
||||
for j in range(n_per_thread / 2)]
|
||||
for j in range(n_per_thread // 2)]
|
||||
first, second = util.tree_reduce(first_and_second, tuples)
|
||||
results[2*i], results[2*i+1] = first, second
|
||||
|
||||
|
||||
@@ -29,6 +29,8 @@ class Beaver : public ProtocolBase<T>
|
||||
typename T::MAC_Check* MC;
|
||||
|
||||
public:
|
||||
static const bool uses_triples = true;
|
||||
|
||||
Player& P;
|
||||
|
||||
Beaver(Player& P) : prep(0), MC(0), P(P) {}
|
||||
@@ -39,6 +41,9 @@ public:
|
||||
void exchange();
|
||||
T finalize_mul(int n = -1);
|
||||
|
||||
void start_exchange();
|
||||
void stop_exchange();
|
||||
|
||||
int get_n_relevant_players() { return P.num_players(); }
|
||||
};
|
||||
|
||||
|
||||
@@ -50,6 +50,20 @@ void Beaver<T>::exchange()
|
||||
triple = triples.begin();
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void Beaver<T>::start_exchange()
|
||||
{
|
||||
MC->POpen_Begin(opened, shares, P);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void Beaver<T>::stop_exchange()
|
||||
{
|
||||
MC->POpen_End(opened, shares, P);
|
||||
it = opened.begin();
|
||||
triple = triples.begin();
|
||||
}
|
||||
|
||||
template<class T>
|
||||
T Beaver<T>::finalize_mul(int n)
|
||||
{
|
||||
|
||||
39
Protocols/HemiPrep.h
Normal file
39
Protocols/HemiPrep.h
Normal file
@@ -0,0 +1,39 @@
|
||||
/*
|
||||
* HemiPrep.h
|
||||
*
|
||||
*/
|
||||
|
||||
#ifndef PROTOCOLS_HEMIPREP_H_
|
||||
#define PROTOCOLS_HEMIPREP_H_
|
||||
|
||||
#include "ReplicatedPrep.h"
|
||||
#include "FHEOffline/Multiplier.h"
|
||||
|
||||
template<class T>
|
||||
class HemiPrep : public SemiHonestRingPrep<T>
|
||||
{
|
||||
typedef typename T::clear::FD FD;
|
||||
|
||||
static PairwiseMachine* pairwise_machine;
|
||||
static Lock lock;
|
||||
|
||||
vector<Multiplier<FD>*> multipliers;
|
||||
|
||||
SeededPRNG G;
|
||||
|
||||
map<string, Timer> timers;
|
||||
|
||||
public:
|
||||
static void basic_setup(Player& P);
|
||||
static void teardown();
|
||||
|
||||
HemiPrep(SubProcessor<T>* proc, DataPositions& usage) :
|
||||
RingPrep<T>(proc, usage), SemiHonestRingPrep<T>(proc, usage)
|
||||
{
|
||||
}
|
||||
|
||||
void buffer_triples();
|
||||
void buffer_inverses();
|
||||
};
|
||||
|
||||
#endif /* PROTOCOLS_HEMIPREP_H_ */
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user