mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-01-10 05:57:57 -05:00
Dijkstra's algorithm in binary circuits.
This commit is contained in:
@@ -786,6 +786,8 @@ class sbitvec(_vec, _bit):
|
||||
return self.from_vec(x.zero_if_not(condition) for x in self.v)
|
||||
def __str__(self):
|
||||
return 'sbitvec(%d)' % n
|
||||
sbitvecn.basic_type = sbitvecn
|
||||
sbitvecn.reg_type = 'sb'
|
||||
return sbitvecn
|
||||
@classmethod
|
||||
def from_vec(cls, vector):
|
||||
@@ -859,7 +861,6 @@ class sbitvec(_vec, _bit):
|
||||
def __invert__(self):
|
||||
return self.from_vec(~x for x in self.v)
|
||||
def if_else(self, x, y):
|
||||
assert(len(self.v) == 1)
|
||||
return util.if_else(self.v[0], x, y)
|
||||
def __iter__(self):
|
||||
return iter(self.v)
|
||||
@@ -873,6 +874,7 @@ class sbitvec(_vec, _bit):
|
||||
return cls.from_vec(other.v)
|
||||
else:
|
||||
return cls(other)
|
||||
hard_conv = conv
|
||||
@property
|
||||
def size(self):
|
||||
if not self.v or util.is_constant(self.v[0]):
|
||||
@@ -1040,7 +1042,10 @@ sbits.dynamic_array = DynamicArray
|
||||
cbits.dynamic_array = Array
|
||||
|
||||
def _complement_two_extend(bits, k):
|
||||
return bits[:k] + [bits[-1]] * (k - len(bits))
|
||||
if len(bits) == 1:
|
||||
return bits + [0] * (k - len(bits))
|
||||
else:
|
||||
return bits[:k] + [bits[-1]] * (k - len(bits))
|
||||
|
||||
class _sbitintbase:
|
||||
def extend(self, n):
|
||||
@@ -1226,10 +1231,9 @@ class sbitintvec(sbitvec, _bitint, _number, _sbitintbase):
|
||||
if util.is_zero(other):
|
||||
return self
|
||||
other = self.coerce(other)
|
||||
assert(len(self.v) == len(other.v))
|
||||
a, b = self.expand(other)
|
||||
v = sbitint.bit_adder(a, b)
|
||||
return self.from_vec(v)
|
||||
return self.get_type(len(v)).from_vec(v)
|
||||
__radd__ = __add__
|
||||
def __mul__(self, other):
|
||||
if isinstance(other, sbits):
|
||||
|
||||
@@ -99,7 +99,7 @@ class HeapQ(object):
|
||||
bits.reverse()
|
||||
bits = [0] + floatingpoint.PreOR(bits, self.levels)
|
||||
bits = [bits[i+1] - bits[i] for i in range(self.levels)]
|
||||
shift = sum([bit << i for i,bit in enumerate(bits)])
|
||||
shift = self.int_type.bit_compose(bits)
|
||||
childpos = MemValue(start * shift)
|
||||
@for_range(self.levels - 1)
|
||||
def f(i):
|
||||
@@ -215,12 +215,13 @@ class HeapQ(object):
|
||||
print_ln()
|
||||
print_ln()
|
||||
|
||||
def dijkstra(source, edges, e_index, oram_type, n_loops=None, int_type=sint):
|
||||
basic_type = int_type.basic_type
|
||||
def dijkstra(source, edges, e_index, oram_type, n_loops=None, int_type=None):
|
||||
vert_loops = n_loops * e_index.size // edges.size \
|
||||
if n_loops else -1
|
||||
dist = oram_type(e_index.size, entry_size=(32,log2(e_index.size)), \
|
||||
init_rounds=vert_loops, value_type=basic_type)
|
||||
init_rounds=vert_loops, value_type=int_type)
|
||||
int_type = dist.value_type
|
||||
basic_type = int_type.basic_type
|
||||
#visited = ORAM(e_index.size)
|
||||
#previous = oram_type(e_index.size)
|
||||
Q = HeapQ(e_index.size, oram_type, init_rounds=vert_loops, \
|
||||
@@ -240,7 +241,7 @@ def dijkstra(source, edges, e_index, oram_type, n_loops=None, int_type=sint):
|
||||
u = MemValue(basic_type(0))
|
||||
@for_range(n_loops or edges.size)
|
||||
def f(i):
|
||||
cint(i).print_reg('loop')
|
||||
print_ln('loop %s', i)
|
||||
time()
|
||||
u.write(if_else(last_edge, Q.pop(last_edge), u))
|
||||
#visited.access(u, True, last_edge)
|
||||
|
||||
@@ -290,13 +290,13 @@ def get_arg():
|
||||
ldarg(res)
|
||||
return res
|
||||
|
||||
def make_array(l):
|
||||
def make_array(l, t=None):
|
||||
if isinstance(l, program.Tape.Register):
|
||||
res = Array(len(l), type(l))
|
||||
res = Array(len(l), t or type(l))
|
||||
res[:] = l
|
||||
else:
|
||||
l = list(l)
|
||||
res = Array(len(l), type(l[0]) if l else cint)
|
||||
res = Array(len(l), t or type(l[0]) if l else cint)
|
||||
res.assign(l)
|
||||
return res
|
||||
|
||||
|
||||
@@ -805,11 +805,11 @@ class RefTrivialORAM(EndRecursiveEviction):
|
||||
class TrivialORAM(RefTrivialORAM, AbstractORAM):
|
||||
""" Trivial ORAM (obviously). """
|
||||
ref_type = RefTrivialORAM
|
||||
def __init__(self, size, value_type=sint, value_length=1, index_size=None, \
|
||||
def __init__(self, size, value_type=None, value_length=1, index_size=None, \
|
||||
entry_size=None, contiguous=True, init_rounds=-1):
|
||||
self.index_size = index_size or log2(size)
|
||||
self.value_type = value_type
|
||||
self.index_type = value_type.get_type(self.index_size)
|
||||
self.value_type = value_type or sint
|
||||
self.index_type = self.value_type.get_type(self.index_size)
|
||||
if entry_size is None:
|
||||
self.value_length = value_length
|
||||
self.entry_size = [None] * value_length
|
||||
@@ -880,7 +880,9 @@ class LinearORAM(TrivialORAM):
|
||||
empty_entry = self.empty_entry(False)
|
||||
demux_array(bit_decompose(index, self.index_size), \
|
||||
self.index_vector)
|
||||
new_value = make_array(new_value)
|
||||
new_value = make_array(
|
||||
new_value, self.value_type.get_type(
|
||||
max(x or 0 for x in self.entry_size)))
|
||||
@for_range_multithread(get_n_threads(self.size), n_parallel, self.size)
|
||||
def f(i):
|
||||
entry = self.ram[i]
|
||||
@@ -896,7 +898,9 @@ class LinearORAM(TrivialORAM):
|
||||
empty_entry = self.empty_entry(False)
|
||||
index_vector = \
|
||||
demux_array(bit_decompose(index, self.index_size))
|
||||
new_value = make_array(new_value)
|
||||
new_value = make_array(
|
||||
new_value, self.value_type.get_type(
|
||||
max(x or 0 for x in self.entry_size)))
|
||||
new_empty = MemValue(new_empty)
|
||||
write = MemValue(write)
|
||||
@map_sum(get_n_threads(self.size), n_parallel, self.size, \
|
||||
@@ -1680,7 +1684,7 @@ class OneLevelORAM(TreeORAM):
|
||||
class BinaryORAM:
|
||||
def __init__(self, size, value_type=None, **kwargs):
|
||||
import circuit_oram
|
||||
from GC import types
|
||||
from Compiler.GC import types
|
||||
n_bits = int(get_program().options.binary)
|
||||
self.value_type = value_type or types.sbitintvec.get_type(n_bits)
|
||||
self.index_type = self.value_type
|
||||
@@ -1689,13 +1693,26 @@ class BinaryORAM:
|
||||
kwargs['entry_size'] = n_bits
|
||||
self.oram = circuit_oram.OptimalCircuitORAM(
|
||||
size, value_type=oram_value_type, **kwargs)
|
||||
self.size = size
|
||||
def get_index(self, index):
|
||||
return self.oram.value_type(self.index_type.conv(index).elements()[0])
|
||||
def __setitem__(self, index, value):
|
||||
self.oram[self.get_index(index)] = self.oram.value_type(
|
||||
self.value_type.conv(value).elements()[0])
|
||||
value = list(self.oram.value_type(
|
||||
self.value_type.conv(v).elements()[0]) for v in tuplify(value))
|
||||
self.oram[self.get_index(index)] = value
|
||||
def __getitem__(self, index):
|
||||
return self.value_type(self.oram[self.get_index(index)])
|
||||
value = self.oram[self.get_index(index)]
|
||||
return untuplify(tuple(self.value_type(v) for v in tuplify(value)))
|
||||
def read(self, index):
|
||||
return self.oram.read(index)
|
||||
def read_and_maybe_remove(self, index):
|
||||
return self.oram.read_and_maybe_remove(index)
|
||||
def access(self, *args):
|
||||
return self.oram.access(*args)
|
||||
def add(self, *args, **kwargs):
|
||||
return self.oram.add(*args, **kwargs)
|
||||
def delete(self, *args, **kwargs):
|
||||
return self.oram.delete(*args, **kwargs)
|
||||
|
||||
def OptimalORAM(size,*args,**kwargs):
|
||||
""" Create an ORAM instance suitable for the size based on
|
||||
|
||||
@@ -6705,9 +6705,10 @@ class MemValue(_mem):
|
||||
:return: relevant basic type instance """
|
||||
self.check()
|
||||
if program.curr_block != self.last_write_block:
|
||||
from Compiler.GC.types import sbitvec
|
||||
self.register = self.value_type.load_mem(
|
||||
self.address, size=self.size \
|
||||
if issubclass(self.value_type, _register) else None)
|
||||
if issubclass(self.value_type, (_register, sbitvec)) else None)
|
||||
self.last_write_block = program.curr_block
|
||||
return self.register
|
||||
|
||||
|
||||
Reference in New Issue
Block a user