Dijkstra's algorithm in binary circuits.

This commit is contained in:
Marcel Keller
2022-10-24 22:17:55 +11:00
parent ed7a474300
commit 6f553cd1f2
5 changed files with 45 additions and 22 deletions

View File

@@ -786,6 +786,8 @@ class sbitvec(_vec, _bit):
return self.from_vec(x.zero_if_not(condition) for x in self.v)
def __str__(self):
return 'sbitvec(%d)' % n
sbitvecn.basic_type = sbitvecn
sbitvecn.reg_type = 'sb'
return sbitvecn
@classmethod
def from_vec(cls, vector):
@@ -859,7 +861,6 @@ class sbitvec(_vec, _bit):
def __invert__(self):
return self.from_vec(~x for x in self.v)
def if_else(self, x, y):
assert(len(self.v) == 1)
return util.if_else(self.v[0], x, y)
def __iter__(self):
return iter(self.v)
@@ -873,6 +874,7 @@ class sbitvec(_vec, _bit):
return cls.from_vec(other.v)
else:
return cls(other)
hard_conv = conv
@property
def size(self):
if not self.v or util.is_constant(self.v[0]):
@@ -1040,7 +1042,10 @@ sbits.dynamic_array = DynamicArray
cbits.dynamic_array = Array
def _complement_two_extend(bits, k):
return bits[:k] + [bits[-1]] * (k - len(bits))
if len(bits) == 1:
return bits + [0] * (k - len(bits))
else:
return bits[:k] + [bits[-1]] * (k - len(bits))
class _sbitintbase:
def extend(self, n):
@@ -1226,10 +1231,9 @@ class sbitintvec(sbitvec, _bitint, _number, _sbitintbase):
if util.is_zero(other):
return self
other = self.coerce(other)
assert(len(self.v) == len(other.v))
a, b = self.expand(other)
v = sbitint.bit_adder(a, b)
return self.from_vec(v)
return self.get_type(len(v)).from_vec(v)
__radd__ = __add__
def __mul__(self, other):
if isinstance(other, sbits):

View File

@@ -99,7 +99,7 @@ class HeapQ(object):
bits.reverse()
bits = [0] + floatingpoint.PreOR(bits, self.levels)
bits = [bits[i+1] - bits[i] for i in range(self.levels)]
shift = sum([bit << i for i,bit in enumerate(bits)])
shift = self.int_type.bit_compose(bits)
childpos = MemValue(start * shift)
@for_range(self.levels - 1)
def f(i):
@@ -215,12 +215,13 @@ class HeapQ(object):
print_ln()
print_ln()
def dijkstra(source, edges, e_index, oram_type, n_loops=None, int_type=sint):
basic_type = int_type.basic_type
def dijkstra(source, edges, e_index, oram_type, n_loops=None, int_type=None):
vert_loops = n_loops * e_index.size // edges.size \
if n_loops else -1
dist = oram_type(e_index.size, entry_size=(32,log2(e_index.size)), \
init_rounds=vert_loops, value_type=basic_type)
init_rounds=vert_loops, value_type=int_type)
int_type = dist.value_type
basic_type = int_type.basic_type
#visited = ORAM(e_index.size)
#previous = oram_type(e_index.size)
Q = HeapQ(e_index.size, oram_type, init_rounds=vert_loops, \
@@ -240,7 +241,7 @@ def dijkstra(source, edges, e_index, oram_type, n_loops=None, int_type=sint):
u = MemValue(basic_type(0))
@for_range(n_loops or edges.size)
def f(i):
cint(i).print_reg('loop')
print_ln('loop %s', i)
time()
u.write(if_else(last_edge, Q.pop(last_edge), u))
#visited.access(u, True, last_edge)

View File

@@ -290,13 +290,13 @@ def get_arg():
ldarg(res)
return res
def make_array(l):
def make_array(l, t=None):
if isinstance(l, program.Tape.Register):
res = Array(len(l), type(l))
res = Array(len(l), t or type(l))
res[:] = l
else:
l = list(l)
res = Array(len(l), type(l[0]) if l else cint)
res = Array(len(l), t or type(l[0]) if l else cint)
res.assign(l)
return res

View File

@@ -805,11 +805,11 @@ class RefTrivialORAM(EndRecursiveEviction):
class TrivialORAM(RefTrivialORAM, AbstractORAM):
""" Trivial ORAM (obviously). """
ref_type = RefTrivialORAM
def __init__(self, size, value_type=sint, value_length=1, index_size=None, \
def __init__(self, size, value_type=None, value_length=1, index_size=None, \
entry_size=None, contiguous=True, init_rounds=-1):
self.index_size = index_size or log2(size)
self.value_type = value_type
self.index_type = value_type.get_type(self.index_size)
self.value_type = value_type or sint
self.index_type = self.value_type.get_type(self.index_size)
if entry_size is None:
self.value_length = value_length
self.entry_size = [None] * value_length
@@ -880,7 +880,9 @@ class LinearORAM(TrivialORAM):
empty_entry = self.empty_entry(False)
demux_array(bit_decompose(index, self.index_size), \
self.index_vector)
new_value = make_array(new_value)
new_value = make_array(
new_value, self.value_type.get_type(
max(x or 0 for x in self.entry_size)))
@for_range_multithread(get_n_threads(self.size), n_parallel, self.size)
def f(i):
entry = self.ram[i]
@@ -896,7 +898,9 @@ class LinearORAM(TrivialORAM):
empty_entry = self.empty_entry(False)
index_vector = \
demux_array(bit_decompose(index, self.index_size))
new_value = make_array(new_value)
new_value = make_array(
new_value, self.value_type.get_type(
max(x or 0 for x in self.entry_size)))
new_empty = MemValue(new_empty)
write = MemValue(write)
@map_sum(get_n_threads(self.size), n_parallel, self.size, \
@@ -1680,7 +1684,7 @@ class OneLevelORAM(TreeORAM):
class BinaryORAM:
def __init__(self, size, value_type=None, **kwargs):
import circuit_oram
from GC import types
from Compiler.GC import types
n_bits = int(get_program().options.binary)
self.value_type = value_type or types.sbitintvec.get_type(n_bits)
self.index_type = self.value_type
@@ -1689,13 +1693,26 @@ class BinaryORAM:
kwargs['entry_size'] = n_bits
self.oram = circuit_oram.OptimalCircuitORAM(
size, value_type=oram_value_type, **kwargs)
self.size = size
def get_index(self, index):
return self.oram.value_type(self.index_type.conv(index).elements()[0])
def __setitem__(self, index, value):
self.oram[self.get_index(index)] = self.oram.value_type(
self.value_type.conv(value).elements()[0])
value = list(self.oram.value_type(
self.value_type.conv(v).elements()[0]) for v in tuplify(value))
self.oram[self.get_index(index)] = value
def __getitem__(self, index):
return self.value_type(self.oram[self.get_index(index)])
value = self.oram[self.get_index(index)]
return untuplify(tuple(self.value_type(v) for v in tuplify(value)))
def read(self, index):
return self.oram.read(index)
def read_and_maybe_remove(self, index):
return self.oram.read_and_maybe_remove(index)
def access(self, *args):
return self.oram.access(*args)
def add(self, *args, **kwargs):
return self.oram.add(*args, **kwargs)
def delete(self, *args, **kwargs):
return self.oram.delete(*args, **kwargs)
def OptimalORAM(size,*args,**kwargs):
""" Create an ORAM instance suitable for the size based on

View File

@@ -6705,9 +6705,10 @@ class MemValue(_mem):
:return: relevant basic type instance """
self.check()
if program.curr_block != self.last_write_block:
from Compiler.GC.types import sbitvec
self.register = self.value_type.load_mem(
self.address, size=self.size \
if issubclass(self.value_type, _register) else None)
if issubclass(self.value_type, (_register, sbitvec)) else None)
self.last_write_block = program.curr_block
return self.register