mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-01-09 13:37:58 -05:00
1319 lines
44 KiB
Python
1319 lines
44 KiB
Python
from Compiler.types import cint,sint,cfix,sfix,sfloat,MPCThread,Array,MemValue,cgf2n,sgf2n,_number,_mem,_register,regint,Matrix,_types, cfloat
|
|
from Compiler.instructions import *
|
|
from Compiler.util import tuplify,untuplify
|
|
from Compiler import instructions,instructions_base,comparison,program,util
|
|
import inspect,math
|
|
import random
|
|
import collections
|
|
|
|
def get_program():
    """ Return the program object held by the instructions module. """
    current_program = instructions.program
    return current_program
|
|
def get_tape():
    """ Return the current tape (``curr_tape``) of the program. """
    current_program = get_program()
    return current_program.curr_tape
|
|
def get_block():
    """ Return the current block (``curr_block``) of the program. """
    current_program = get_program()
    return current_program.curr_block
|
|
|
|
def vectorize(function):
    """ Decorator running `function` under a global vector size.

    The size is taken from the first positional argument if it is a
    register (its ``size`` attribute), otherwise from a ``size`` keyword
    argument, which is removed before `function` is called. Without
    either, `function` is called unchanged.

    The global vector size is reset even if `function` raises, so that
    an exception caught further up does not leave the size set.
    """
    def vectorized_function(*args, **kwargs):
        if len(args) > 0 and isinstance(args[0], program.Tape.Register):
            size = args[0].size
        elif 'size' in kwargs:
            size = kwargs.pop('size')
        else:
            return function(*args, **kwargs)
        instructions_base.set_global_vector_size(size)
        try:
            return function(*args, **kwargs)
        finally:
            instructions_base.reset_global_vector_size()
    vectorized_function.__name__ = function.__name__
    return vectorized_function
|
|
|
|
def set_instruction_type(function):
    """ Decorator selecting the global instruction type for `function`.

    If the first positional argument is a register, the type is set to
    'gf2n' or 'modp' according to its ``is_gf2n`` attribute for the
    duration of the call. Otherwise `function` is called unchanged.

    The global instruction type is reset even if `function` raises.
    """
    def instruction_typed_function(*args, **kwargs):
        if not (len(args) > 0 and isinstance(args[0], program.Tape.Register)):
            return function(*args, **kwargs)
        if args[0].is_gf2n:
            instructions_base.set_global_instruction_type('gf2n')
        else:
            instructions_base.set_global_instruction_type('modp')
        try:
            return function(*args, **kwargs)
        finally:
            instructions_base.reset_global_instruction_type()
    instruction_typed_function.__name__ = function.__name__
    return instruction_typed_function
|
|
|
|
|
|
def print_str(s, *args):
    """ Print a string, substituting each '%s' with the matching argument.

    Clear registers, cfix and cfloat values are printed with their own
    print methods, lists are printed element-wise in brackets, and
    anything else is printed via its ``output()`` method or, failing
    that, as ``str(val)``. MemValue arguments are read first.

    Raises CompilerError if the number of arguments does not match the
    number of '%s' or if a secret value is passed.
    """
    def print_plain_str(ss):
        """ Print a plain string (no custom formatting options) """
        # Groups of four characters are used as long as more than four
        # characters remain; the rest is printed one character at a time.
        n_groups = max(len(ss) - 1, 0) // 4
        for g in range(n_groups):
            print_char4(ss[4 * g:4 * g + 4])
        for ch in ss[4 * n_groups:]:
            print_char(ch)

    if s.count('%s') != len(args):
        raise CompilerError('Incorrect number of arguments for string format:', s)
    for index, piece in enumerate(s.split('%s')):
        print_plain_str(piece)
        if index >= len(args):
            continue
        arg = args[index]
        val = arg.read() if isinstance(arg, MemValue) else arg
        if isinstance(val, program.Tape.Register):
            if not val.is_clear:
                raise CompilerError('Cannot print secret value:', arg)
            val.print_reg_plain()
        elif isinstance(val, cfix):
            val.print_plain()
        elif isinstance(val, (sfix, sfloat)):
            raise CompilerError('Cannot print secret value:', arg)
        elif isinstance(val, cfloat):
            val.print_float_plain()
        elif isinstance(val, list):
            print_str('[' + ', '.join(['%s'] * len(val)) + ']', *val)
        else:
            try:
                val.output()
            except AttributeError:
                print_plain_str(str(val))
|
|
|
|
def print_ln(s='', *args):
    """ Print a formatted string like print_str, followed by a newline. """
    print_str(s, *args)
    newline = '\n'
    print_char(newline)
|
|
|
|
def print_ln_if(cond, s):
    """ Print the line `s` if `cond` holds.

    A constant condition is decided here via print_ln; otherwise the
    line (with trailing newline) is handed to ``cond.print_if`` in
    pieces of at most four characters.
    """
    if util.is_constant(cond):
        if cond:
            print_ln(s)
        return
    line = s + '\n'
    for offset in range(0, len(line), 4):
        cond.print_if(line[offset:offset + 4])
|
|
|
|
def runtime_error(msg='', *args):
    """ Print 'User exception: ' plus the formatted message, then call crash(). """
    prefix = 'User exception: '
    print_str(prefix)
    print_ln(msg, *args)
    crash()
|
|
|
|
def public_input():
    """ Issue a pubinput instruction into a fresh regint and return it. """
    value = regint()
    pubinput(value)
    return value
|
|
|
|
# mostly obsolete functions
|
|
# use the equivalent from types.py
|
|
|
|
def load_int(value, size=None):
    """ Return ``regint(value, size=size)``. """
    loaded = regint(value, size=size)
    return loaded
|
|
|
|
def load_int_to_secret(value, size=None):
    """ Return ``sint(value, size=size)``. """
    loaded = sint(value, size=size)
    return loaded
|
|
|
|
def load_int_to_secret_vector(vector):
    """ Load the values of `vector` into a sint of matching size.

    Each element is loaded with one ldsi instruction, in order.
    """
    n_values = len(vector)
    res = sint(size=n_values)
    for position in range(n_values):
        ldsi(res[position], vector[position])
    return res
|
|
|
|
@vectorize
def load_float_to_secret(value, sec=40):
    """ Build an sfloat from the Python float `value`.

    The float is split into an integer ratio, the numerator is shifted
    to exactly ``sfloat.vlen`` bits and the exponent is adjusted by the
    same amount. `sec` is not used by this function.

    Raises CompilerError if the exponent needs more than
    ``sfloat.plen`` bits. (The original raised ``CompilerException``,
    a name this file never imports explicitly; every other error here
    is a CompilerError.)
    """
    def _bit_length(x):
        # number of binary digits, ignoring sign; 0 for x == 0
        return len(bin(x).lstrip('-0b'))

    num,den = value.as_integer_ratio()
    # den is treated as a power of two; exp is its base-2 logarithm
    exp = int(round(math.log(den, 2)))

    nbits = _bit_length(num)
    if nbits > sfloat.vlen:
        num >>= (nbits - sfloat.vlen)
        exp -= (nbits - sfloat.vlen)
    elif nbits < sfloat.vlen:
        num <<= (sfloat.vlen - nbits)
        exp += (sfloat.vlen - nbits)

    if _bit_length(exp) > sfloat.plen:
        raise CompilerError('Cannot load floating point to secret: overflow')
    if num < 0:
        s = load_int_to_secret(1)
        z = load_int_to_secret(0)
    else:
        s = load_int_to_secret(0)
        if num == 0:
            z = load_int_to_secret(1)
        else:
            z = load_int_to_secret(0)
    v = load_int_to_secret(num)
    p = load_int_to_secret(exp)
    # NOTE(review): cond_swap in this file rebuilds sfloat from the
    # attributes in the order ('v', 'p', 'z', 's'), whereas the call
    # below passes (v, p, s, z) -- confirm against sfloat.__init__.
    return sfloat(v, p, s, z)
|
|
|
|
def load_clear_mem(address):
    """ Return ``cint.load_mem(address)``. """
    loaded = cint.load_mem(address)
    return loaded
|
|
|
|
def load_secret_mem(address):
    """ Return ``sint.load_mem(address)``. """
    loaded = sint.load_mem(address)
    return loaded
|
|
|
|
def load_mem(address, value_type):
    """ Load from `address` using `value_type`.

    `value_type` is first looked up in ``_types``; if it is not a key
    there, it is used directly and must offer ``load_mem``.
    """
    resolved = _types[value_type] if value_type in _types else value_type
    return resolved.load_mem(address)
|
|
|
|
@vectorize
def store_in_mem(value, address):
    """ Store `value` at `address`.

    Python ints are converted with load_int first. The value's own
    ``store_in_mem`` method is used if available; on AttributeError the
    legacy store instructions are chosen by ``value.is_clear`` and by
    whether `address` is a cint.
    """
    if isinstance(value, int):
        value = load_int(value)
    try:
        value.store_in_mem(address)
    except AttributeError:
        # legacy
        address_is_cint = isinstance(address, cint)
        if value.is_clear:
            store = stmci if address_is_cint else stmc
        else:
            store = stmsi if address_is_cint else stms
        store(value, address)
|
|
|
|
@set_instruction_type
@vectorize
def reveal(secret):
    """ Open `secret` and return the result.

    The value's own ``reveal`` method is used if available; on
    AttributeError a cgf2n or cint register is created depending on
    ``secret.is_gf2n`` and filled with asm_open.
    """
    try:
        return secret.reveal()
    except AttributeError:
        opened = cgf2n() if secret.is_gf2n else cint()
        instructions.asm_open(opened, secret)
        return opened
|
|
|
|
@vectorize
def compare_secret(a, b, length, sec=40):
    """ Emit an lts instruction on `a` and `b` and return its result.

    The result register is a fresh sint passed as the first operand of
    ``instructions.lts`` together with `length` and `sec`. The original
    code created the register but never returned it, so callers always
    received None.
    """
    res = sint()
    instructions.lts(res, a, b, length, sec)
    return res
|
|
|
|
def get_input_from(player, size=None):
    """ Return ``sint.get_input_from(player, size=size)``. """
    secret_input = sint.get_input_from(player, size=size)
    return secret_input
|
|
|
|
def get_random_triple(size=None):
    """ Return ``sint.get_random_triple(size=size)``. """
    triple = sint.get_random_triple(size=size)
    return triple
|
|
|
|
def get_random_bit(size=None):
    """ Return ``sint.get_random_bit(size=size)``. """
    bit = sint.get_random_bit(size=size)
    return bit
|
|
|
|
def get_random_square(size=None):
    """ Return ``sint.get_random_square(size=size)``. """
    square = sint.get_random_square(size=size)
    return square
|
|
|
|
def get_random_inverse(size=None):
    """ Return ``sint.get_random_inverse(size=size)``. """
    inverse = sint.get_random_inverse(size=size)
    return inverse
|
|
|
|
def get_random_int(bits, size=None):
    """ Return ``sint.get_random_int(bits, size=size)``. """
    random_value = sint.get_random_int(bits, size=size)
    return random_value
|
|
|
|
@vectorize
def get_thread_number():
    """ Issue an ldtn instruction into a fresh regint and return it. """
    thread_number = regint()
    ldtn(thread_number)
    return thread_number
|
|
|
|
@vectorize
def get_arg():
    """ Issue an ldarg instruction into a fresh regint and return it. """
    argument = regint()
    ldarg(argument)
    return argument
|
|
|
|
def make_array(l):
    """ Copy `l` into a new Array.

    A single register becomes an Array of length one. Any other
    iterable is turned into a list; the element type is the type of its
    first item, or cint if it is empty.
    """
    if isinstance(l, program.Tape.Register):
        single = Array(1, type(l))
        single[0] = l
        return single
    items = list(l)
    element_type = type(items[0]) if items else cint
    res = Array(len(items), element_type)
    res.assign(items)
    return res
|
|
|
|
|
|
class FunctionTapeCall:
    """ Handle for one invocation of a function compiled as a thread.

    Holds the thread, the address of the argument base block and the
    per-type argument blocks (`bases`, mapping register type to
    address) so they can be freed after joining.
    """
    def __init__(self, thread, base, bases):
        self.thread, self.base, self.bases = thread, base, bases

    def start(self):
        """ Start the thread with the base address; return self. """
        self.thread.start(self.base)
        return self

    def join(self):
        """ Join the thread, then free the base block and all type blocks. """
        self.thread.join()
        instructions.program.free(self.base, 'ci')
        for reg_type in self.bases:
            addr = self.bases[reg_type]
            get_program().free(addr, reg_type.reg_type)
|
|
|
|
class Function:
    """ Base class for compiled functions taking run-time arguments.

    Arguments are passed through memory: per call, one block of 'ci'
    memory holds the base address of one further block per argument
    type, and each of those holds the arguments of that type in order.
    Subclasses provide ``on_first_call(wrapped_function)`` to compile
    the body and ``on_call(base, bases)`` to invoke it.
    """
    def __init__(self, function, name=None, compile_args=[]):
        # maps number of arguments -> {argument type: [argument positions]}
        self.type_args = {}
        self.function = function
        self.name = name
        if name is None:
            # default name is made unique with the id of the function
            self.name = self.function.__name__ + '-' + str(id(function))
        # compile_args is stored as given and never modified here
        self.compile_args = compile_args
    def __call__(self, *args):
        """ Store `args` in memory and invoke the function via on_call.

        MemValue arguments are read first. The first call with a given
        number of arguments triggers on_first_call; later calls with
        that number of arguments must use the same types, otherwise
        CompilerError('type mismatch') is raised.
        """
        args = tuple(arg.read() if isinstance(arg, MemValue) else arg for arg in args)
        # Python ints (and Python 2 longs) are passed as regint
        get_reg_type = lambda x: regint if isinstance(x, (int, long)) else type(x)
        if len(args) not in self.type_args:
            # first call
            type_args = collections.defaultdict(list)
            for i,arg in enumerate(args):
                type_args[get_reg_type(arg)].append(i)
            def wrapped_function(*compile_args):
                # inside the compiled function: recover the arguments
                # from memory using the same iteration order over
                # type_args as the storing code below
                base = get_arg()
                bases = dict((t, regint.load_mem(base + i)) \
                                 for i,t in enumerate(type_args))
                runtime_args = [None] * len(args)
                for t,i_args in type_args.iteritems():
                    for i,i_arg in enumerate(i_args):
                        runtime_args[i_arg] = t.load_mem(bases[t] + i)
                return self.function(*(list(compile_args) + runtime_args))
            self.on_first_call(wrapped_function)
            self.type_args[len(args)] = type_args
        type_args = self.type_args[len(args)]
        # one 'ci' cell per argument type, holding that type's block address
        base = instructions.program.malloc(len(type_args), 'ci')
        bases = dict((t, get_program().malloc(len(type_args[t]), t)) \
                         for t in type_args)
        for i,reg_type in enumerate(type_args):
            store_in_mem(bases[reg_type], base + i)
            for j,i_arg in enumerate(type_args[reg_type]):
                if get_reg_type(args[i_arg]) != reg_type:
                    raise CompilerError('type mismatch')
                store_in_mem(args[i_arg], bases[reg_type] + j)
        return self.on_call(base, bases)
|
|
|
|
class FunctionTape(Function):
    """ Function compiled as a separate thread (MPCThread). """
    # not thread-safe
    def on_first_call(self, wrapped_function):
        """ Create the MPCThread for the wrapped function. """
        thread = MPCThread(wrapped_function, self.name,
                           args=self.compile_args)
        self.thread = thread

    def on_call(self, base, bases):
        """ Return a FunctionTapeCall for this thread and argument blocks. """
        call = FunctionTapeCall(self.thread, base, bases)
        return call
|
|
|
|
def function_tape(function):
    """ Decorator wrapping `function` in a FunctionTape. """
    tape_function = FunctionTape(function)
    return tape_function
|
|
|
|
def function_tape_with_compile_args(*args):
    """ Decorator factory: wrap a function in a FunctionTape with
    `args` passed as its compile_args. """
    def wrapper(function):
        tape_function = FunctionTape(function, compile_args=args)
        return tape_function
    return wrapper
|
|
|
|
def memorize(x):
    """ Wrap `x` in a MemValue; tuples and lists are converted
    recursively into tuples of MemValues. """
    if not isinstance(x, (tuple, list)):
        return MemValue(x)
    return tuple(memorize(item) for item in x)
|
|
|
|
def unmemorize(x):
    """ Inverse of memorize: call ``read()`` on `x`; tuples and lists
    are converted recursively into tuples of read values. """
    if not isinstance(x, (tuple, list)):
        return x.read()
    return tuple(unmemorize(item) for item in x)
|
|
|
|
class FunctionBlock(Function):
    """ Function compiled as a block of the current tape.

    The body is compiled once into its own scope; every call jumps to
    it after storing a return address in memory.
    """
    def on_first_call(self, wrapped_function):
        """ Compile the function body into a new scope of the current tape. """
        old_block = get_tape().active_basicblock
        parent_node = get_tape().req_node
        get_tape().open_scope(lambda x: x[0], None, 'begin-' + self.name)
        block = get_tape().active_basicblock
        block.alloc_pool = defaultdict(set)
        # drop the child that open_scope appended to the parent node;
        # the node is attached on every call in on_call instead
        del parent_node.children[-1]
        self.node = get_tape().req_node
        print 'Compiling function', self.name
        result = wrapped_function(*self.compile_args)
        # results are kept as MemValues and read back after each call
        if result is not None:
            self.result = memorize(result)
        else:
            self.result = None
        print 'Done compiling function', self.name
        # memory cell holding the return address, written by on_call
        p_return_address = get_tape().program.malloc(1, 'ci')
        get_tape().function_basicblocks[block] = p_return_address
        return_address = regint.load_mem(p_return_address)
        get_tape().active_basicblock.set_exit(instructions.jmpi(return_address, add_to_prog=False))
        self.last_sub_block = get_tape().active_basicblock
        get_tape().close_scope(old_block, parent_node, 'end-' + self.name)
        # the block active before compilation jumps over the function body
        old_block.set_exit(instructions.jmp(0, add_to_prog=False), get_tape().active_basicblock)
        self.basic_block = block

    def on_call(self, base, bases):
        """ Emit a call to the compiled block and return its result, if any.

        Raises CompilerError if the block is not registered in the
        current tape's function_basicblocks.
        """
        if base is not None:
            instructions.starg(regint(base))
        block = self.basic_block
        if block not in get_tape().function_basicblocks:
            raise CompilerError('unknown function')
        old_block = get_tape().active_basicblock
        old_block.set_exit(instructions.jmp(0, add_to_prog=False), block)
        p_return_address = get_tape().function_basicblocks[block]
        return_address = get_tape().new_reg('ci')
        # the ldint is emitted with 0 and kept on the block as
        # return_address_store
        old_block.return_address_store = instructions.ldint(return_address, 0)
        instructions.stmint(return_address, p_return_address)
        get_tape().start_new_basicblock(name='call-' + self.name)
        get_tape().active_basicblock.set_return(old_block, self.last_sub_block)
        get_tape().req_node.children.append(self.node)
        if self.result is not None:
            return unmemorize(self.result)
|
|
|
|
def function_block(function):
    """ Decorator wrapping `function` in a FunctionBlock. """
    block_function = FunctionBlock(function)
    return block_function
|
|
|
|
def function_block_with_compile_args(*args):
    """ Decorator factory: wrap a function in a FunctionBlock with
    `args` passed as its compile_args. """
    def wrapper(function):
        block_function = FunctionBlock(function, compile_args=args)
        return block_function
    return wrapper
|
|
|
|
def method_block(function):
    """ Decorator compiling a method as one FunctionBlock per instance.

    The block is created on the first call for each instance, with the
    instance as the only compile-time argument, and reused afterwards.
    """
    # If you use this, make sure to use MemValue for all member
    # variables.
    compiled_functions = {}
    def wrapper(self, *args):
        if self not in compiled_functions:
            name = '%s-%s-%d' % (type(self).__name__, function.__name__,
                                 id(self))
            compiled_functions[self] = FunctionBlock(function, name=name,
                                                     compile_args=(self,))
        return compiled_functions[self](*args)
    return wrapper
|
|
|
|
def cond_swap(x,y):
    """ Return the pair selected by the comparison ``b = x < y``.

    The first result is ``b*x + y - b*y`` and the second
    ``x - b*x + b*y``. For sfloat inputs the same selection is applied
    to each of the attributes 'v', 'p', 'z', 's' and two new sfloats
    are built from the results.
    """
    b = x < y
    if not isinstance(x, sfloat):
        bx = b * x
        by = b * y
        return bx + y - by, x - bx + by
    first_parts, second_parts = [], []
    for attribute in ('v', 'p', 'z', 's'):
        xx = x.__getattribute__(attribute)
        yy = y.__getattribute__(attribute)
        bx = b * xx
        by = b * yy
        first_parts.append(bx + yy - by)
        second_parts.append(xx - bx + by)
    return sfloat(*first_parts), sfloat(*second_parts)
|
|
|
|
def sort(a):
    """ Sort `a` in place with adjacent cond_swap calls; return `a`.

    For every position i, neighbouring pairs are swapped conditionally
    from index i-1 down to 0.
    """
    for i in range(len(a)):
        j = i - 1
        while j >= 0:
            a[j], a[j + 1] = cond_swap(a[j], a[j + 1])
            j -= 1
    return a
|
|
|
|
def odd_even_merge(a):
    """ Merge step of odd-even merge sort, in place on `a`.

    Two elements are swapped conditionally. Otherwise the even- and
    odd-indexed elements are merged recursively and the inner pairs
    (odd[i-1], even[i]) are swapped conditionally.
    """
    if len(a) == 2:
        a[0], a[1] = cond_swap(a[0], a[1])
        return
    even, odd = a[::2], a[1::2]
    odd_even_merge(even)
    odd_even_merge(odd)
    a[0] = even[0]
    # len() is an int, so floor division matches the Python 2 '/'
    for i in range(1, len(a) // 2):
        a[2*i-1], a[2*i] = cond_swap(odd[i-1], even[i])
    a[-1] = odd[-1]
|
|
|
|
def odd_even_merge_sort(a):
    """ Sort `a` in place with odd-even merge sort.

    Both halves are sorted recursively and then merged with
    odd_even_merge. Lists of length 0 or 1 are left unchanged; the
    original only stopped at length 1, so an empty list recursed
    without end.

    Raises CompilerError if a (sub)list of odd length greater than one
    is encountered.
    """
    if len(a) <= 1:
        return
    if len(a) % 2 != 0:
        raise CompilerError('Length of list must be power of two')
    # len() is an int, so floor division matches the Python 2 '/'
    half = len(a) // 2
    lower = a[:half]
    upper = a[half:]
    odd_even_merge_sort(lower)
    odd_even_merge_sort(upper)
    a[:] = lower + upper
    odd_even_merge(a)
|
|
|
|
def chunky_odd_even_merge_sort(a):
    """ Odd-even merge sort of `a`, running every round in its own thread.

    The elements are stored in memory at ``i * sizeof()`` first; each
    round loads them, applies its conditional swaps and stores them
    back. After the last round the elements are loaded into `a` again.
    """
    for i,j in enumerate(a):
        j.store_in_mem(i * j.sizeof())
    l = 1
    while l < len(a):
        l *= 2
        k = 1
        while k < l:
            k *= 2
            # note: this local function shadows the builtin round()
            # within this function
            def round():
                for i in range(len(a)):
                    a[i] = type(a[i]).load_mem(i * a[i].sizeof())
                for i in range(len(a) / l):
                    for j in range(l / k):
                        base = i * l + j
                        step = l / k
                        if k == 2:
                            a[base], a[base+step] = cond_swap(a[base], a[base+step])
                        else:
                            # b is not used afterwards
                            b = a[base:base+k*step:step]
                            for m in range(base + step, base + (k - 1) * step, 2 * step):
                                a[m], a[m+step] = cond_swap(a[m], a[m+step])
                for i in range(len(a)):
                    a[i].store_in_mem(i * a[i].sizeof())
            # random suffix in the thread name
            chunk = MPCThread(round, 'sort-%d-%d-%03x' % (l,k,random.randrange(256**3)))
            chunk.start()
            chunk.join()
            #round()
    for i in range(len(a)):
        a[i] = type(a[i]).load_mem(i * a[i].sizeof())
|
|
|
|
def chunkier_odd_even_merge_sort(a, n=None, max_chunk_size=512, n_threads=7, use_chunk_wraps=False):
    """ Odd-even merge sort on secret memory using chunked threads.

    If `n` is None, `a` is a list that is copied to freshly allocated
    's' memory and written back at the end. Otherwise `a` is used as
    the base address of `n` values already in memory.

    In every round the pairs to compare are copied to a temporary
    memory region, swapped conditionally by FunctionTape threads
    working on chunks (at most `n_threads` started before joining),
    and copied back.

    NOTE(review): all divisions rely on Python 2 integer '/'.
    """
    if n is None:
        n = len(a)
        a_base = instructions.program.malloc(n, 's')
        for i,j in enumerate(a):
            store_in_mem(j, a_base + i)
        instructions.program.restart_main_thread()
    else:
        a_base = a
    tmp_base = instructions.program.malloc(n, 's')
    # chunk size -> FunctionTape swapping neighbouring pairs
    chunks = {}
    # threads collected for the next batch
    threads = []

    def run_threads():
        # start all collected threads, wait for all, then clear the list
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
        del threads[:]

    def run_chunk(size, base):
        # returns a call object for swapping `size` values from `base`;
        # one FunctionTape is compiled per distinct size
        if size not in chunks:
            def swap_list(list_base):
                for i in range(size / 2):
                    base = list_base + 2 * i
                    x, y = cond_swap(load_secret_mem(base),
                                     load_secret_mem(base + 1))
                    store_in_mem(x, base)
                    store_in_mem(y, base + 1)
            chunks[size] = FunctionTape(swap_list, 'sort-%d-%03x' %
                                        (size, random.randrange(256**3)))
        return chunks[size](base)

    def run_round(size):
        # minimize number of chunk sizes
        n_chunks = int(math.ceil(1.0 * size / max_chunk_size))
        # even chunk size; the remaining chunks get two more values
        lower_size = size / n_chunks / 2 * 2
        n_lower_size = n_chunks - (size - n_chunks * lower_size) / 2
        # print len(to_swap) == lower_size * n_lower_size + \
        #     (lower_size + 2) * (n_chunks - n_lower_size), \
        #     len(to_swap), n_chunks, lower_size, n_lower_size
        base = 0
        round_threads = []
        for i in range(n_lower_size):
            round_threads.append(run_chunk(lower_size, tmp_base + base))
            base += lower_size
        for i in range(n_chunks - n_lower_size):
            round_threads.append(run_chunk(lower_size + 2, tmp_base + base))
            base += lower_size + 2
        run_threads_in_rounds(round_threads)

    # (mem_op, args) pairs to run inline after the round
    postproc_chunks = []
    # k -> (pre_chunk, post_chunk) FunctionTapes
    wrap_chunks = {}
    post_threads = []
    pre_threads = []

    def load_and_store(x, y, to_right):
        # copy x -> y if to_right, else y -> x
        if to_right:
            store_in_mem(load_secret_mem(x), y)
        else:
            store_in_mem(load_secret_mem(y), x)

    def run_setup(k, a_addr, step, tmp_addr):
        # arranges copying of the values compared in this sub-step to
        # the temporary region (and back afterwards); returns how many
        # values are copied
        if k == 2:
            def mem_op(preproc, a_addr, step, tmp_addr):
                load_and_store(a_addr, tmp_addr, preproc)
                load_and_store(a_addr + step, tmp_addr + 1, preproc)
            res = 2
        else:
            def mem_op(preproc, a_addr, step, tmp_addr):
                instructions.program.curr_tape.merge_opens = False
                # for i,m in enumerate(range(a_addr + step, a_addr + (k - 1) * step, step)):
                for i in range(k - 2):
                    m = a_addr + step + i * step
                    load_and_store(m, tmp_addr + i, preproc)
            res = k - 2
        if not use_chunk_wraps or k <= 4:
            # copy inline now, remember the copy back for run_postproc
            mem_op(True, a_addr, step, tmp_addr)
            postproc_chunks.append((mem_op, (a_addr, step, tmp_addr)))
        else:
            # copy in threads, compiled once per k
            if k not in wrap_chunks:
                pre_chunk = FunctionTape(mem_op, 'pre-%d-%03x' % (k,random.randrange(256**3)),
                                         compile_args=[True])
                post_chunk = FunctionTape(mem_op, 'post-%d-%03x' % (k,random.randrange(256**3)),
                                          compile_args=[False])
                wrap_chunks[k] = (pre_chunk, post_chunk)
            pre_chunk, post_chunk = wrap_chunks[k]
            pre_threads.append(pre_chunk(a_addr, step, tmp_addr))
            post_threads.append(post_chunk(a_addr, step, tmp_addr))
        return res

    def run_threads_in_rounds(all_threads):
        # run the given calls in batches of at most n_threads, then
        # empty the given list
        for thread in all_threads:
            if len(threads) == n_threads:
                run_threads()
            threads.append(thread)
        run_threads()
        del all_threads[:]

    def run_postproc():
        run_threads_in_rounds(post_threads)
        for chunk,args in postproc_chunks:
            chunk(False, *args)
        postproc_chunks[:] = []

    l = 1
    while l < n:
        l *= 2
        k = 1
        while k < l:
            k *= 2
            # number of values placed in the temporary region so far
            size = 0
            instructions.program.curr_tape.merge_opens = False
            for i in range(n / l):
                for j in range(l / k):
                    base = i * l + j
                    step = l / k
                    size += run_setup(k, a_base + base, step, tmp_base + size)
            run_threads_in_rounds(pre_threads)
            run_round(size)
            run_postproc()

    if isinstance(a, list):
        instructions.program.restart_main_thread()
        for i in range(n):
            a[i] = load_secret_mem(a_base + i)
        instructions.program.free(a_base, 's')
    instructions.program.free(tmp_base, 's')
|
|
|
|
def loopy_chunkier_odd_even_merge_sort(a, n=None, max_chunk_size=512, n_threads=7):
    """ Variant of chunkier_odd_even_merge_sort using run-time loops.

    The copying between the array and the temporary region is done
    with range_loop instead of unrolled code; the next free temporary
    address is kept in the 'ci' memory cell `tmp_i`. The conditional
    swaps are run by FunctionTape threads on chunks as in
    chunkier_odd_even_merge_sort.

    If `n` is None, `a` is a list that is copied to memory and written
    back at the end; otherwise `a` is the base address of `n` values.

    NOTE(review): all divisions rely on Python 2 integer '/'.
    """
    if n is None:
        n = len(a)
        a_base = instructions.program.malloc(n, 's')
        for i,j in enumerate(a):
            store_in_mem(j, a_base + i)
        instructions.program.restart_main_thread()
    else:
        a_base = a
    tmp_base = instructions.program.malloc(n, 's')
    tmp_i = instructions.program.malloc(1, 'ci')
    # chunk size -> FunctionTape swapping neighbouring pairs
    chunks = {}
    # threads collected for the next batch
    threads = []

    def run_threads():
        # start all collected threads, wait for all, then clear the list
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
        del threads[:]

    def run_threads_in_rounds(all_threads):
        # run the given calls in batches of at most n_threads, then
        # empty the given list
        for thread in all_threads:
            if len(threads) == n_threads:
                run_threads()
            threads.append(thread)
        run_threads()
        del all_threads[:]

    def run_chunk(size, base):
        # one FunctionTape is compiled per distinct size
        if size not in chunks:
            def swap_list(list_base):
                for i in range(size / 2):
                    base = list_base + 2 * i
                    x, y = cond_swap(load_secret_mem(base),
                                     load_secret_mem(base + 1))
                    store_in_mem(x, base)
                    store_in_mem(y, base + 1)
            chunks[size] = FunctionTape(swap_list, 'sort-%d-%03x' %
                                        (size, random.randrange(256**3)))
        return chunks[size](base)

    def run_round(size):
        # minimize number of chunk sizes
        n_chunks = int(math.ceil(1.0 * size / max_chunk_size))
        lower_size = size / n_chunks / 2 * 2
        n_lower_size = n_chunks - (size - n_chunks * lower_size) / 2
        # print len(to_swap) == lower_size * n_lower_size + \
        #     (lower_size + 2) * (n_chunks - n_lower_size), \
        #     len(to_swap), n_chunks, lower_size, n_lower_size
        base = 0
        round_threads = []
        for i in range(n_lower_size):
            round_threads.append(run_chunk(lower_size, tmp_base + base))
            base += lower_size
        for i in range(n_chunks - n_lower_size):
            round_threads.append(run_chunk(lower_size + 2, tmp_base + base))
            base += lower_size + 2
        run_threads_in_rounds(round_threads)

    l = 1
    while l < n:
        l *= 2
        k = 1
        while k < l:
            k *= 2
            def load_and_store(x, y):
                # direction is taken from to_tmp, which is assigned
                # below before each use of outer
                if to_tmp:
                    store_in_mem(load_secret_mem(x), y)
                else:
                    store_in_mem(load_secret_mem(y), x)
            def outer(i):
                def inner(j):
                    base = j
                    step = l / k
                    if k == 2:
                        tmp_addr = regint.load_mem(tmp_i)
                        load_and_store(base, tmp_addr)
                        load_and_store(base + step, tmp_addr + 1)
                        store_in_mem(tmp_addr + 2, tmp_i)
                    else:
                        def inner2(m):
                            tmp_addr = regint.load_mem(tmp_i)
                            load_and_store(m, tmp_addr)
                            store_in_mem(tmp_addr + 1, tmp_i)
                        range_loop(inner2, base + step, base + (k - 1) * step, step)
                range_loop(inner, a_base + i * l, a_base + i * l + l / k)
            instructions.program.curr_tape.merge_opens = False
            # copy to the temporary region, starting at tmp_base
            to_tmp = True
            store_in_mem(tmp_base, tmp_i)
            range_loop(outer, n / l)
            if k == 2:
                run_round(n)
            else:
                run_round(n / k * (k - 2))
            instructions.program.curr_tape.merge_opens = False
            # copy back from the temporary region
            to_tmp = False
            store_in_mem(tmp_base, tmp_i)
            range_loop(outer, n / l)

    if isinstance(a, list):
        instructions.program.restart_main_thread()
        for i in range(n):
            a[i] = load_secret_mem(a_base + i)
        instructions.program.free(a_base, 's')
    instructions.program.free(tmp_base, 's')
    instructions.program.free(tmp_i, 'ci')
|
|
|
|
|
|
def loopy_odd_even_merge_sort(a, sorted_length=1, n_parallel=32):
    """ Odd-even merge sort of `a` in place using for_range_parallel loops.

    `sorted_length` is the starting value of the merge width `l`, which
    is doubled before each merge level. `n_parallel` is divided among
    the nested loops as their parallelism argument.

    NOTE(review): all divisions rely on Python 2 integer '/'.
    """
    l = sorted_length
    while l < len(a):
        l *= 2
        k = 1
        while k < l:
            k *= 2
            n_outer = len(a) / l
            n_inner = l / k
            n_innermost = 1 if k == 2 else k / 2 - 1
            @for_range_parallel(n_parallel / n_innermost / n_inner, n_outer)
            def loop(i):
                @for_range_parallel(n_parallel / n_innermost, n_inner)
                def inner(j):
                    base = i*l + j
                    step = l/k
                    if k == 2:
                        a[base], a[base+step] = cond_swap(a[base], a[base+step])
                    else:
                        @for_range_parallel(n_parallel, n_innermost)
                        def f(i):
                            # this i shadows the outer loop index
                            m1 = step + i * 2 * step
                            m2 = m1 + base
                            a[m2], a[m2+step] = cond_swap(a[m2], a[m2+step])
|
|
|
|
def mergesort(A):
    """ Bottom-up merge sort of `A` in place, using a second sint Array.

    Note that the result of every comparison ``A[i0] <= A[i1]`` is
    revealed to drive the run-time branch in the merge.
    """
    B = Array(len(A), sint)

    def merge(i_left, i_right, i_end):
        # merge A[i_left:i_right] and A[i_right:i_end] into B
        i0 = MemValue(i_left)
        i1 = MemValue(i_right)
        @for_range(i_left, i_end)
        def loop(j):
            # take from the left run if it is not exhausted and either
            # the right run is exhausted or the left element is not larger
            if_then(and_(lambda: i0 < i_right,
                         or_(lambda: i1 >= i_end,
                             lambda: regint(reveal(A[i0] <= A[i1])))))
            B[j] = A[i0]
            i0.iadd(1)
            else_then()
            B[j] = A[i1]
            i1.iadd(1)
            end_if()

    # current run width, doubled after every pass
    width = MemValue(1)
    @do_while
    def width_loop():
        @for_range(0, len(A), 2 * width)
        def merge_loop(i):
            merge(i, i + width, i + 2 * width)
        A.assign(B)
        width.imul(2)
        return width < len(A)
|
|
|
|
def range_loop(loop_body, start, stop=None, step=None):
    """ Run `loop_body(i)` for i from `start` to `stop` in steps of `step`.

    With a single bound, the loop runs from 0 to that bound. `step`
    defaults to 1. The loop is emitted with while_loop; for a
    non-constant `step`, the direction is decided from ``step > 0`` at
    run time.

    If all three bounds are Python ints and the loop runs at least
    once, the aggregator of the last child of the tape's req_node is
    set to multiply by the number of iterations. This count is the
    ceiling of (stop - start) / step; the original used floor division
    and thus undercounted whenever the range length is not a multiple
    of the step.

    Raises CompilerError if `step` is the constant 0.
    """
    if stop is None:
        stop = start
        start = 0
    if step is None:
        step = 1
    def loop_fn(i):
        loop_body(i)
        return i + step
    if isinstance(step, int):
        if step > 0:
            condition = lambda x: x < stop
        elif step < 0:
            condition = lambda x: x > stop
        else:
            raise CompilerError('step must not be zero')
    else:
        b = step > 0
        condition = lambda x: b * (x < stop) + (1 - b) * (x > stop)
    while_loop(loop_fn, condition, start)
    if isinstance(start, int) and isinstance(stop, int) \
            and isinstance(step, int):
        # known loop count
        if condition(start):
            # ceiling division valid for both signs of step
            n_iterations = -((start - stop) // step)
            get_tape().req_node.children[-1].aggregator = \
                lambda x: n_iterations * x[0]
|
|
|
|
def for_range(start, stop=None, step=None):
    """ Decorator running the decorated function as the body of a
    range_loop with the given bounds; the function itself is returned. """
    def decorator(loop_body):
        range_loop(loop_body, start, stop=stop, step=step)
        return loop_body
    return decorator
|
|
|
|
def for_range_parallel(n_parallel, n_loops):
    """ Decorator for a loop of `n_loops` iterations without state.

    This is map_reduce_single with an initializer and a reducer that
    both return an empty list.
    """
    no_state = lambda *x: []
    discard = lambda *x: []
    return map_reduce_single(n_parallel, n_loops, no_state, discard)
|
|
|
|
def map_reduce_single(n_parallel, n_loops, initializer, reducer, mem_state=None):
    """ Decorator factory for a single-threaded map-reduce loop.

    The decorated `loop_body(j)` is run for j in 0..n_loops-1. A
    run-time loop executes `n_parallel` unrolled iterations per round
    and folds their results with `reducer` into the state kept in
    memory; remaining iterations are unrolled (constant `n_loops`) or
    run in a second run-time loop. The decorator returns a function
    returning the final state.

    `mem_state` is an Array holding the state for the multithreaded
    version; if None, a list of MemValues is created from
    `initializer()`.

    Raises CompilerError if `n_parallel` is not a Python int. The
    original raised the undefined name ``CompilerException`` and its
    message lacked the space between 'executions' and 'must'.
    """
    if not isinstance(n_parallel, int):
        raise CompilerError('Number of parallel executions ' \
                                'must be constant')
    n_parallel = n_parallel or 1
    if mem_state is None:
        # default to list of MemValues to allow varying types
        mem_state = [type(x).MemValue(x) for x in initializer()]
        use_array = False
    else:
        # use Arrays for multithread version
        use_array = True
    def decorator(loop_body):
        if isinstance(n_loops, int):
            # relies on Python 2 integer division
            loop_rounds = n_loops / n_parallel \
                          if n_parallel < n_loops else 0
        else:
            loop_rounds = n_loops / n_parallel
        def write_state_to_memory(r):
            if use_array:
                mem_state.assign(r)
            else:
                # cannot do mem_state = [...] due to scope issue
                for j,x in enumerate(r):
                    mem_state[j].write(x)
        # will be optimized out if n_loops <= n_parallel
        @for_range(loop_rounds)
        def f(i):
            state = tuplify(initializer())
            for k in range(n_parallel):
                j = i * n_parallel + k
                state = reducer(tuplify(loop_body(j)), state)
            r = reducer(mem_state, state)
            write_state_to_memory(r)
        if isinstance(n_loops, int):
            # unroll the iterations not covered by full rounds
            state = mem_state
            for j in range(loop_rounds * n_parallel, n_loops):
                state = reducer(tuplify(loop_body(j)), state)
        else:
            @for_range(loop_rounds * n_parallel, n_loops)
            def f(j):
                r = reducer(tuplify(loop_body(j)), mem_state)
                write_state_to_memory(r)
            state = mem_state
        # write the final state back to memory
        for i,x in enumerate(state):
            if use_array:
                mem_state[i] = x
            else:
                mem_state[i].write(x)
        def returner():
            return untuplify(tuple(state))
        return returner
    return decorator
|
|
|
|
def for_range_multithread(n_threads, n_parallel, n_loops, thread_mem_req={}):
    """ Decorator for a multithreaded loop of `n_loops` iterations
    without state.

    This is map_reduce with an initializer and a reducer that both
    return an empty list.
    """
    no_state = lambda *x: []
    discard = lambda *x: []
    return map_reduce(n_threads, n_parallel, n_loops, no_state, discard,
                      thread_mem_req)
|
|
|
|
def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \
                   thread_mem_req={}):
    """ Decorator factory for a multithreaded map-reduce loop.

    The `n_loops` iterations are split among `n_threads` tapes; the
    first ``n_loops % n_threads`` threads run one iteration more than
    the others. Each thread reduces into its own state Array, and the
    per-thread states are reduced again after joining. With one thread
    or one iteration, map_reduce_single is used directly.

    `thread_mem_req` may request regint memory per thread (key regint,
    value the number of cells); the memory is passed to the loop body
    as second argument. Other keys raise CompilerError.

    The decorator returns a function returning the final state.
    """
    n_threads = n_threads or 1
    if n_threads == 1 or n_loops == 1:
        dec = map_reduce_single(n_parallel, n_loops, initializer, reducer)
        if thread_mem_req:
            thread_mem = Array(thread_mem_req[regint], regint)
            return lambda loop_body: dec(lambda i: loop_body(i, thread_mem))
        else:
            return dec
    def decorator(loop_body):
        # relies on Python 2 integer division
        thread_rounds = n_loops / n_threads
        remainder = n_loops % n_threads
        for t in thread_mem_req:
            if t != regint:
                raise CompilerError('Not implemented for other than regint')
        # per thread: [first loop index, state address, thread memory...]
        args = Matrix(n_threads, 2 + thread_mem_req.get(regint, 0), 'ci')
        state = tuple(initializer())
        def f(inc):
            # thread body; inc is 1 for threads running an extra iteration
            if thread_mem_req:
                thread_mem = Array(thread_mem_req[regint], regint, \
                                       args[get_arg()].address + 2)
            mem_state = Array(len(state), type(state[0]) \
                                  if state else cint, args[get_arg()][1])
            base = args[get_arg()][0]
            @map_reduce_single(n_parallel, thread_rounds + inc, \
                                   initializer, reducer, mem_state)
            def f(i):
                if thread_mem_req:
                    return loop_body(base + i, thread_mem)
                else:
                    return loop_body(base + i)
        prog = get_program()
        threads = []
        if thread_rounds:
            tape = prog.new_tape(f, (0,), 'multithread')
            # threads remainder..n_threads-1 run thread_rounds iterations
            for i in range(n_threads - remainder):
                mem_state = make_array(initializer())
                args[remainder + i][0] = i * thread_rounds
                if len(mem_state):
                    args[remainder + i][1] = mem_state.address
                threads.append(prog.run_tape(tape, remainder + i))
        if remainder:
            tape1 = prog.new_tape(f, (1,), 'multithread1')
            # threads 0..remainder-1 run thread_rounds + 1 iterations,
            # starting after the iterations assigned above
            for i in range(remainder):
                mem_state = make_array(initializer())
                args[i][0] = (n_threads - remainder + i) * thread_rounds + i
                if len(mem_state):
                    args[i][1] = mem_state.address
                threads.append(prog.run_tape(tape1, i))
        for thread in threads:
            prog.join_tape(thread)
        if state:
            if thread_rounds:
                for i in range(n_threads - remainder):
                    state = reducer(Array(len(state), type(state[0]), \
                                              args[remainder + i][1]), state)
            if remainder:
                # NOTE(review): this branch passes type(state[0]).reg_type
                # where the branch above passes type(state[0]) -- confirm
                # that Array accepts both forms
                for i in range(remainder):
                    state = reducer(Array(len(state), type(state[0]).reg_type, \
                                              args[i][1]), state)
        def returner():
            return untuplify(state)
        return returner
    return decorator
|
|
|
|
def map_sum(n_threads, n_parallel, n_loops, n_items, value_types):
    """ Decorator factory summing `n_items` results over all iterations.

    `value_types` is one type used for all items or one type per item.
    The state starts as ``t(0)`` for each type and is reduced by
    element-wise addition via map_reduce.

    Raises CompilerError if the number of types is neither 1 nor
    `n_items`.
    """
    value_types = tuplify(value_types)
    n_types = len(value_types)
    if n_types == 1:
        value_types *= n_items
    elif n_types != n_items:
        raise CompilerError('Incorrect number of value_types.')

    def initializer():
        return [value_type(0) for value_type in value_types]

    def summer(x, y):
        return tuple(left + right for left, right in zip(x, y))

    return map_reduce(n_threads, n_parallel, n_loops, initializer, summer)
|
|
|
|
def foreach_enumerate(a):
    """ Decorator running a loop over the compile-time values in `a`.

    Every element of `a` is registered with the program as public
    input (its components joined by spaces). The loop body is called
    with the index followed by one public_input() per component of the
    first element.
    """
    for element in a:
        line = ' '.join(str(component) for component in tuplify(element))
        get_program().public_input(line)
    def decorator(loop_body):
        @for_range(len(a))
        def f(i):
            n_components = len(tuplify(a[0]))
            loop_body(i, *(public_input() for j in range(n_components)))
        return f
    return decorator
|
|
|
|
def while_loop(loop_body, condition, arg):
    """ Run `loop_body` while `condition` holds, starting from `arg`.

    `loop_body` receives the current loop variable and returns the next
    one; `condition` is called on the loop variable. The variable is
    kept on the regint stack between iterations. Nothing is emitted if
    the condition on `arg` is a constant false value.

    Raises CompilerError if `condition` is not callable.
    """
    if not callable(condition):
        raise CompilerError('Condition must be callable')
    # store arg in stack
    pre_condition = condition(arg)
    # skip everything only if the condition is known to be false
    if not isinstance(pre_condition, (bool,int)) or pre_condition:
        pushint(arg if isinstance(arg,regint) else regint(arg))
        def loop_fn():
            result = loop_body(regint.pop())
            pushint(result)
            return condition(result)
        if_statement(pre_condition, lambda: do_while(loop_fn))
        # remove the loop variable left on the stack
        regint.pop()
|
|
|
|
def while_do(condition, *args):
    """ Decorator running the decorated function as the body of a
    while_loop with the given condition and arguments; the function
    itself is returned. """
    def decorator(loop_body):
        loop_arguments = (loop_body, condition) + args
        while_loop(*loop_arguments)
        return loop_body
    return decorator
|
|
|
|
def do_loop(condition, loop_fn):
    """ Run `loop_fn` repeatedly, passing the previous return value.

    The first call receives `condition`; the loop continues as long as
    the value returned by `loop_fn` is non-zero. The value is kept on
    the regint stack between iterations.
    """
    # store initial condition to stack
    initial = condition if isinstance(condition, regint) else regint(condition)
    pushint(initial)
    def wrapped_loop():
        # take the previous value from the stack, run the loop body
        # and save the new value for the next iteration
        next_condition = loop_fn(regint.pop())
        pushint(next_condition)
        return next_condition
    do_while(wrapped_loop)
    regint.pop()
|
|
|
|
def do_while(loop_fn):
    """ Emit a loop running `loop_fn` once and then again as long as
    its return value is non-zero; returns `loop_fn`.

    The return value may also be a callable producing the condition.
    """
    scope = instructions.program.curr_block
    parent_node = get_tape().req_node
    # possibly unknown loop count
    get_tape().open_scope(lambda x: x[0].set_all(float('Inf')), \
                              name='begin-loop')
    loop_block = instructions.program.curr_block
    condition = loop_fn()
    if callable(condition):
        condition = condition()
    # jump back to the start of the loop while the condition is non-zero
    branch = instructions.jmpnz(regint.conv(condition), 0, add_to_prog=False)
    instructions.program.curr_block.set_exit(branch, loop_block)
    get_tape().close_scope(scope, parent_node, 'end-loop')
    return loop_fn
|
|
|
|
def if_then(condition):
    """ Open an if block on `condition` (a value or a callable
    producing one).

    The state of the block is pushed onto the current tape's if_states
    and must be closed with end_if(), optionally after else_then().
    Nothing is returned.
    """
    class State: pass
    state = State()
    if callable(condition):
        condition = condition()
    state.condition = regint.conv(condition)
    state.start_block = instructions.program.curr_block
    state.req_child = get_tape().open_scope(lambda x: x[0].max(x[1]), \
                                                   name='if-block')
    state.has_else = False
    instructions.program.curr_tape.if_states.append(state)
|
|
|
|
def else_then():
    """ Start the else block of the innermost open if block.

    Raises CompilerError if no if block is open or if it already has
    an else block.
    """
    try:
        state = instructions.program.curr_tape.if_states[-1]
    except IndexError:
        raise CompilerError('No open if block')
    if state.has_else:
        raise CompilerError('else block already defined')
    # run the else block
    # remember where the if branch ended so end_if can make it skip the else
    state.if_exit_block = instructions.program.curr_block
    state.req_child.add_node(get_tape(), 'else-block')
    instructions.program.curr_tape.start_new_basicblock(state.start_block, \
                                                            name='else-block')
    state.else_block = instructions.program.curr_block
    state.has_else = True
|
|
|
|
def end_if():
    """
    Close the innermost open if/else block and attach its jumps.

    Raises CompilerError if there is no open block.
    """
    try:
        state = instructions.program.curr_tape.if_states.pop()
    except IndexError:
        raise CompilerError('No open if/else block')
    skip_if = instructions.jmpeqz(regint.conv(state.condition), 0,
                                  add_to_prog=False)
    # start next block
    get_tape().close_scope(state.start_block, state.req_child.parent, 'end-if')
    if not state.has_else:
        # set start block's conditional jump to next block
        state.start_block.set_exit(skip_if, instructions.program.curr_block)
        # nothing to compute without else
        state.req_child.aggregator = lambda x: x[0]
        return
    # jump to else block if condition == 0
    state.start_block.set_exit(skip_if, state.else_block)
    # set if block to skip else
    skip_else = instructions.jmp(0, add_to_prog=False)
    state.if_exit_block.set_exit(skip_else, instructions.program.curr_block)
|
|
|
|
def if_statement(condition, if_fn, else_fn=None):
    """
    Run if_fn or else_fn depending on condition.

    If condition is the Python constant True or False, the decision is
    made at compile time and only the selected function is called
    (nothing is called for False without else_fn). Otherwise an if block
    is emitted via if_then()/else_then()/end_if() and both functions are
    called to generate their respective branches.
    """
    if isinstance(condition, bool):
        # condition known at compile time
        if condition:
            if_fn()
        elif else_fn is not None:
            else_fn()
    else:
        # if_then() returns nothing; its state lives on the tape's stack
        if_then(condition)
        if_fn()
        if else_fn is not None:
            else_then()
            else_fn()
        end_if()
|
|
|
|
def if_(condition):
    """
    Decorator factory for an if block without else.

    The decorated function is called immediately between if_then(condition)
    and end_if(). The decorator returns nothing.
    """
    def run_guarded(body):
        if_then(condition)
        body()
        end_if()
    return run_guarded
|
|
|
|
def if_e(condition):
    """
    Decorator factory for the if branch of an if/else block.

    The decorated function is called immediately after if_then(condition).
    The block is left open: end_if() is not called here. The decorator
    returns nothing.
    """
    def run_if_branch(body):
        if_then(condition)
        body()
    return run_if_branch
|
|
|
|
def else_(body):
    """
    Decorator for the else branch of the innermost open if block.

    Switches to the else branch, calls body immediately, then closes the
    block. Returns nothing.
    """
    else_then()
    body()
    end_if()
|
|
|
|
def and_(*terms):
    """
    Conjunction of conditions given as callables.

    The code of each term is generated inside the if block opened on the
    previous term. 1 is stored in the innermost block, 0 in every else
    branch. Returns a callable that loads the stored value as regint and
    frees the memory cell.
    """
    # not thread-safe
    result_addr = instructions.program.malloc(1, 'ci')
    # one nested if block per term
    for term in terms:
        if_then(term())
    store_in_mem(1, result_addr)
    # close the blocks again, innermost first
    for _ in terms:
        else_then()
        store_in_mem(0, result_addr)
        end_if()

    def load_result():
        value = regint.load_mem(result_addr)
        instructions.program.free(result_addr, 'ci')
        return value
    return load_result
|
|
|
|
def or_(*terms):
    """
    Disjunction of conditions given as callables.

    For every term an if block is opened that stores 1; the code of the
    next term is generated inside the else branch of the previous one,
    and the innermost else branch stores 0. Returns a callable that loads
    the stored value as regint and frees the memory cell.
    """
    # not thread-safe
    p_res = instructions.program.malloc(1, 'ci')
    for term in terms:
        if_then(term())
        store_in_mem(1, p_res)
        else_then()
    store_in_mem(0, p_res)
    # close all the blocks opened above
    for _ in terms:
        end_if()

    def load_result():
        res = regint.load_mem(p_res)
        instructions.program.free(p_res, 'ci')
        return res
    return load_result
|
|
|
|
def not_(term):
    """ Return a callable computing 1 - term() each time it is called. """
    def negated():
        return 1 - term()
    return negated
|
|
|
|
def start_timer(timer_id=0):
    """
    Emit the start instruction for timer timer_id, with a new basic
    block started directly before and directly after it.
    """
    tape = get_tape()
    tape.start_new_basicblock(name='pre-start-timer')
    start(timer_id)
    tape.start_new_basicblock(name='post-start-timer')
|
|
|
|
def stop_timer(timer_id=0):
    """
    Emit the stop instruction for timer timer_id, with a new basic
    block started directly before and directly after it.
    """
    tape = get_tape()
    tape.start_new_basicblock(name='pre-stop-timer')
    stop(timer_id)
    tape.start_new_basicblock(name='post-stop-timer')
|
|
|
|
# Fixed point ops
|
|
|
|
from math import ceil, log
|
|
from floatingpoint import PreOR, TruncPr, two_power, shift_two
|
|
|
|
def approximate_reciprocal(divisor, k, f, theta):
    """
        returns approximation of 1/divisor
        where type(divisor) = cint

        k is the number of bits divisor is decomposed into, theta the
        number of refinement iterations, and the result is shifted by
        2*k - 2*f - (number of counted leading zero bits).
    """
    def twos_complement(x):
        # flip each of the k bits of x and add one
        bits = x.bit_decompose(k)[::-1]
        bit_array = Array(k, cint)
        bit_array.assign(bits)

        twos_result = MemValue(cint(0))
        @for_range(k)
        def block(i):
            val = twos_result.read()
            val <<= 1
            val += 1 - bit_array[i]
            twos_result.write(val)

        return twos_result.read() + 1

    # reversed bit decomposition (presumably most significant bit first)
    bit_array = Array(k, cint)
    bits = divisor.bit_decompose(k)[::-1]
    bit_array.assign(bits)

    flag = MemValue(regint(0))
    # allocated exactly once; a second allocation would orphan the first
    cnt_leading_zeros = MemValue(regint(0))
    normalized_divisor = MemValue(divisor)

    @for_range(k)
    def block(i):
        # NOTE: `|` binds tighter than `==`, so this is (flag | bit) == 1
        flag.write(flag.read() | bit_array[i] == 1)
        # until the first set bit: count a leading zero, shift divisor left
        @if_(flag.read() == 0)
        def block():
            cnt_leading_zeros.write(cnt_leading_zeros.read() + 1)
            normalized_divisor.write(normalized_divisor << 1)

    q = MemValue(two_power(k))
    e = MemValue(twos_complement(normalized_divisor.read()))

    qr = q.read()
    er = e.read()

    # refinement, unrolled at compile time
    for i in range(theta):
        qr = qr + shift_two(qr * er, k)
        er = shift_two(er * er, k)

    q = qr
    res = shift_two(q, (2*k - 2*f - cnt_leading_zeros))

    return res
|
|
|
|
|
|
def cint_cint_division(a, b, k, f):
    """
        Division of two cint values using Goldschmidt iteration, with
        the initial estimate taken from approximate_reciprocal
        (SE approximation):
        http://stackoverflow.com/questions/2661541/picking-good-first-estimates-for-goldschmidt-division
    """
    # theta can be replaced with something smaller
    # for safety we assume that is the same theta from previous GS method
    rounds = int(ceil(log(k/3.5) / log(2)))
    fixed_two = cint(2) * two_power(f)

    # work on absolute values, reapply the signs at the end
    b_sign = cint(1) - 2 * cint(b < 0)
    a_sign = cint(1) - 2 * cint(a < 0)
    b_abs = b * b_sign
    a_abs = a * a_sign
    initial_factor = approximate_reciprocal(b_abs, k, f, rounds)

    numerators = Array(rounds, cint)
    denominators = Array(rounds, cint)
    factors = Array(rounds, cint)

    numerators[0] = a_abs
    denominators[0] = b_abs
    factors[0] = initial_factor
    for step in range(1, rounds):
        prev = step - 1
        numerators[step] = shift_two(numerators[prev] * factors[prev], f)
        denominators[step] = shift_two(denominators[prev] * factors[prev], f)
        factors[step] = fixed_two - denominators[step]

    return (a_sign * b_sign) * numerators[rounds - 1]
|
|
|
|
from Compiler.program import Program
|
|
def sint_cint_division(a, b, k, f, kappa):
    """
        Division with type(a) = sint, type(b) = cint.

        Same iteration as in cint_cint_division, but the numerator is
        kept in a sint array and truncated with TruncPr.
    """
    rounds = int(ceil(log(k/3.5) / log(2)))
    fixed_two = cint(2) * two_power(f)
    # work on absolute values, reapply the signs at the end
    b_sign = cint(1) - 2 * cint(b < 0)
    a_sign = sint(1) - 2 * sint(a < 0)
    b_abs = b * b_sign
    a_abs = a * a_sign
    initial_factor = approximate_reciprocal(b_abs, k, f, rounds)

    numerators = Array(rounds, sint)
    denominators = Array(rounds, cint)
    factors = Array(rounds, cint)

    numerators[0] = a_abs
    denominators[0] = b_abs
    factors[0] = initial_factor

    @for_range(1, rounds)
    def refine(i):
        numerators[i] = TruncPr(numerators[i - 1] * factors[i - 1], 2*k, f, kappa)
        next_denominator = shift_two(denominators[i - 1] * factors[i - 1], f)
        # no reading and writing to the same variable in a for loop.
        factors[i] = fixed_two - next_denominator
        denominators[i] = next_denominator

    return (a_sign * b_sign) * numerators[rounds - 1]
|
|
|
|
def FPDiv(a, b, k, f, kappa, simplex_flag=False):
    """
        Goldschmidt method as presented in Catrina10.

        Starts from the AppRcr estimate for b and refines the quotient
        theta + 1 times, working on values extended to 2*k bits.
    """
    theta = int(ceil(log(k/3.5) / log(2)))
    width = 2 * k
    alpha = b.get_type(width).two_power(2*f)
    recip_estimate = AppRcr(b, k, f, kappa, simplex_flag).extend(width)
    residual = alpha - b.extend(width) * recip_estimate

    quotient = a.extend(width) * recip_estimate
    quotient = quotient.TruncPr(width, f, kappa)

    for _ in range(theta):
        residual = residual.extend(width)
        quotient = quotient.extend(width) * (alpha + residual).extend(width)
        residual = residual * residual
        quotient = quotient.TruncPr(width, 2*f, kappa)
        residual = residual.TruncPr(width, 2*f, kappa)

    # final refinement step, truncated to the output width
    quotient = quotient.extend(width) * (alpha + residual).extend(width)
    quotient = quotient.TruncPr(k + 2*f, 2*f, kappa)
    return quotient
|
|
def AppRcr(b, k, f, kappa, simplex_flag=False):
    """
        Approximate reciprocal of [b]:
        Given [b], compute [1/b]

        kappa is forwarded both to b.Norm and to the final truncation;
        simplex_flag is forwarded to b.Norm.
    """
    alpha = b.get_type(2 * k)(int(2.9142 * 2**k))
    c, v = b.Norm(k, f, kappa, simplex_flag)
    #v should be 2**{k - m} where m is the length of the bitwise repr of [b]
    d = alpha - 2 * c
    w = d * v
    # pass kappa on so the caller's security parameter is not ignored here
    w = w.TruncPr(2 * k, 2 * (k - f), kappa)
    # now w * 2 ^ {-f} should be an initial approximation of 1/b
    return w
|
|
|
|
def Norm(b, k, f, kappa, simplex_flag=False):
    """
        Computes secret integer values [c] and [v_prime] st.
        2^{k-1} <= c < 2^k and c = b*v_prime

        Returns the pair (c, v_prime). If simplex_flag is true, b is
        treated as non-negative and the sign computation is skipped.
        f is not used in this function.
    """
    # For simplex, we can get rid of computing abs(b)
    # plain truthiness: a flag that is neither True nor False (e.g. None)
    # must not leave the sign bit undefined
    if simplex_flag:
        temp = cint(0)
    else:
        temp = b.less_than(0, 2 * k)

    sign = 1 - 2 * temp # 1 - 2 * [b < 0]
    absolute_val = sign * b

    #next 2 lines actually compute the SufOR for little-endian encoding
    bits = absolute_val.bit_decompose(k, kappa)[::-1]
    suffixes = PreOR(bits)[::-1]

    # z[i] is the difference of neighbouring suffix ORs
    z = [0] * k
    for i in range(k - 1):
        z[i] = suffixes[i] - suffixes[i+1]
    z[k - 1] = suffixes[k-1]

    #doing complicated stuff to compute v = 2^{k-m}
    acc = cint(0)
    for i in range(k):
        acc += two_power(k-i-1) * z[i]

    part_reciprocal = absolute_val * acc
    signed_acc = sign * acc

    return part_reciprocal, signed_acc
|