Files
MP-SPDZ/Compiler/library.py
2021-07-02 15:50:34 +10:00

1752 lines
58 KiB
Python

"""
This module defines functions directly available in high-level programs,
in particularly providing flow control and output.
"""
from Compiler.types import cint,sint,cfix,sfix,sfloat,MPCThread,Array,MemValue,cgf2n,sgf2n,_number,_mem,_register,regint,Matrix,_types, cfloat, _single, localint, personal, copy_doc
from Compiler.instructions import *
from Compiler.util import tuplify,untuplify,is_zero
from Compiler.allocator import RegintOptimizer
from Compiler import instructions,instructions_base,comparison,program,util
import inspect,math
import random
import collections
import operator
from functools import reduce
def get_program():
return instructions.program
def get_tape():
return get_program().curr_tape
def get_block():
return get_program().curr_block
def vectorize(function):
def vectorized_function(*args, **kwargs):
if len(args) > 0 and 'size' in dir(args[0]):
instructions_base.set_global_vector_size(args[0].size)
res = function(*args, **kwargs)
instructions_base.reset_global_vector_size()
elif 'size' in kwargs:
instructions_base.set_global_vector_size(kwargs['size'])
del kwargs['size']
res = function(*args, **kwargs)
instructions_base.reset_global_vector_size()
else:
res = function(*args, **kwargs)
return res
vectorized_function.__name__ = function.__name__
copy_doc(vectorized_function, function)
return vectorized_function
def set_instruction_type(function):
def instruction_typed_function(*args, **kwargs):
if len(args) > 0 and isinstance(args[0], program.Tape.Register):
if args[0].is_gf2n:
instructions_base.set_global_instruction_type('gf2n')
else:
instructions_base.set_global_instruction_type('modp')
res = function(*args, **kwargs)
instructions_base.reset_global_instruction_type()
else:
res = function(*args, **kwargs)
return res
instruction_typed_function.__name__ = function.__name__
return instruction_typed_function
def _expand_to_print(val):
return ('[' + ', '.join('%s' for i in range(len(val))) + ']',) + tuple(val)
def print_str(s, *args):
""" Print a string, with optional args for adding
variables/registers with ``%s``. """
def print_plain_str(ss):
""" Print a plain string (no custom formatting options) """
i = 1
while 4*i <= len(ss):
print_char4(ss[4*(i-1):4*i])
i += 1
i = 4*(i-1)
while i < len(ss):
print_char(ss[i])
i += 1
if len(args) != s.count('%s'):
raise CompilerError('Incorrect number of arguments for string format:', s)
substrings = s.split('%s')
for i,ss in enumerate(substrings):
print_plain_str(ss)
if i < len(args):
if isinstance(args[i], MemValue):
val = args[i].read()
else:
val = args[i]
if isinstance(val, program.Tape.Register):
if val.is_clear:
val.print_reg_plain()
else:
raise CompilerError('Cannot print secret value:', args[i])
elif isinstance(val, cfix):
val.print_plain()
elif isinstance(val, sfix) or isinstance(val, sfloat):
raise CompilerError('Cannot print secret value:', args[i])
elif isinstance(val, cfloat):
val.print_float_plain()
elif isinstance(val, (list, tuple, Array)):
print_str(*_expand_to_print(val))
else:
try:
val.output()
except AttributeError:
print_plain_str(str(val))
def print_ln(s='', *args):
""" Print line, with optional args for adding variables/registers
with ``%s``. By default only player 0 outputs, but the ``-I``
command-line option changes that.
:param s: Python string with same number of ``%s`` as length of :py:obj:`args`
:param args: list of public values (regint/cint/int/cfix/cfloat/localint)
Example:
.. code::
print_ln('a is %s.', a.reveal())
"""
print_str(s + '\n', *args)
def print_ln_if(cond, ss, *args):
""" Print line if :py:obj:`cond` is true. The further arguments
are treated as in :py:func:`print_str`/:py:func:`print_ln`.
:param cond: regint/cint/int/localint
:param ss: Python string
:param args: list of public values
Example:
.. code::
print_ln_if(get_player_id() == 0, 'Player 0 here')
"""
print_str_if(cond, ss + '\n', *args)
def print_str_if(cond, ss, *args):
""" Print string conditionally. See :py:func:`print_ln_if` for details. """
if util.is_constant(cond):
if cond:
print_ln(ss, *args)
else:
subs = ss.split('%s')
assert len(subs) == len(args) + 1
if isinstance(cond, localint):
cond = cond._v
cond = cint.conv(cond)
for i, s in enumerate(subs):
if i != 0:
val = args[i - 1]
try:
val.output_if(cond)
except:
if isinstance(val, (list, tuple, Array)):
print_str_if(cond, *_expand_to_print(val))
else:
print_str_if(cond, str(val))
s += '\0' * ((-len(s)) % 4)
while s:
cond.print_if(s[:4])
s = s[4:]
def print_ln_to(player, ss, *args):
""" Print line at :py:obj:`player` only. Note that printing is
disabled by default except at player 0. Activate interactive mode
with `-I` to enable it for all players.
:param player: int
:param ss: Python string
:param args: list of values known to :py:obj:`player`
Example::
print_ln_to(player, 'output for %s: %s', player, x.reveal_to(player))
"""
cond = player == get_player_id()
new_args = []
for arg in args:
if isinstance(arg, personal):
if util.is_constant(arg.player) ^ util.is_constant(player):
match = False
else:
if util.is_constant(player):
match = arg.player == player
else:
match = id(arg.player) == id(player)
if not match:
raise CompilerError('player mismatch in personal printing')
new_args.append(arg._v)
else:
new_args.append(arg)
print_ln_if(cond, ss, *new_args)
def print_float_precision(n):
""" Set the precision for floating-point printing.
:param n: number of digits (int) """
print_float_prec(n)
def runtime_error(msg='', *args):
""" Print an error message and abort the runtime.
Parameters work as in :py:func:`print_ln` """
print_str('User exception: ')
print_ln(msg, *args)
crash()
def public_input():
""" Public input read from ``Programs/Public-Input/<progname>``. """
res = cint()
pubinput(res)
return res
# mostly obsolete functions
# use the equivalent from types.py
@vectorize
def store_in_mem(value, address):
if isinstance(value, int):
value = regint(value)
try:
value.store_in_mem(address)
except AttributeError:
# legacy
if value.is_clear:
if isinstance(address, cint):
stmci(value, address)
else:
stmc(value, address)
else:
if isinstance(address, cint):
stmsi(value, address)
else:
stms(value, address)
@set_instruction_type
@vectorize
def reveal(secret):
try:
return secret.reveal()
except AttributeError:
if secret.is_gf2n:
res = cgf2n()
else:
res = cint()
instructions.asm_open(res, secret)
return res
@vectorize
def get_thread_number():
""" Returns the thread number. """
res = regint()
ldtn(res)
return res
@vectorize
def get_arg():
""" Returns the thread argument. """
res = regint()
ldarg(res)
return res
def make_array(l):
if isinstance(l, program.Tape.Register):
res = Array(1, type(l))
res[0] = l
else:
l = list(l)
res = Array(len(l), type(l[0]) if l else cint)
res.assign(l)
return res
class FunctionTapeCall:
def __init__(self, thread, base, bases):
self.thread = thread
self.base = base
self.bases = bases
def start(self):
self.thread.start(self.base)
return self
def join(self):
self.thread.join()
instructions.program.free(self.base, 'ci')
for reg_type,addr in self.bases.items():
get_program().free(addr, reg_type.reg_type)
class Function:
def __init__(self, function, name=None, compile_args=[]):
self.type_args = {}
self.function = function
self.name = name
if name is None:
self.name = self.function.__name__
self.compile_args = compile_args
def __call__(self, *args):
args = tuple(arg.read() if isinstance(arg, MemValue) else arg for arg in args)
from .types import _types
get_reg_type = lambda x: \
regint if isinstance(x, int) else _types.get(x.reg_type, type(x))
if len(args) not in self.type_args:
# first call
type_args = collections.defaultdict(list)
for i,arg in enumerate(args):
type_args[get_reg_type(arg)].append(i)
def wrapped_function(*compile_args):
base = get_arg()
bases = dict((t, regint.load_mem(base + i)) \
for i,t in enumerate(sorted(type_args,
key=lambda x:
x.reg_type)))
runtime_args = [None] * len(args)
for t in sorted(type_args, key=lambda x: x.reg_type):
i = 0
for i_arg in type_args[t]:
runtime_args[i_arg] = t.load_mem(bases[t] + i)
i += util.mem_size(t)
return self.function(*(list(compile_args) + runtime_args))
self.on_first_call(wrapped_function)
self.type_args[len(args)] = type_args
type_args = self.type_args[len(args)]
base = instructions.program.malloc(len(type_args), 'ci')
bases = dict((t, get_program().malloc(len(type_args[t]), t)) \
for t in type_args)
for i,reg_type in enumerate(sorted(type_args,
key=lambda x: x.reg_type)):
store_in_mem(bases[reg_type], base + i)
j = 0
for i_arg in type_args[reg_type]:
if get_reg_type(args[i_arg]) != reg_type:
raise CompilerError('type mismatch: "%s" not of type "%s"' %
(args[i_arg], reg_type))
store_in_mem(args[i_arg], bases[reg_type] + j)
j += util.mem_size(reg_type)
return self.on_call(base, bases)
class FunctionTape(Function):
# not thread-safe
def __init__(self, function, name=None, compile_args=[],
single_thread=False):
Function.__init__(self, function, name, compile_args)
self.single_thread = single_thread
def on_first_call(self, wrapped_function):
self.thread = MPCThread(wrapped_function, self.name,
args=self.compile_args,
single_thread=self.single_thread)
def on_call(self, base, bases):
return FunctionTapeCall(self.thread, base, bases)
def function_tape(function):
return FunctionTape(function)
def function_tape_with_compile_args(*args):
def wrapper(function):
return FunctionTape(function, compile_args=args)
return wrapper
def single_thread_function_tape(function):
return FunctionTape(function, single_thread=True)
def memorize(x):
if isinstance(x, (tuple, list)):
return tuple(memorize(i) for i in x)
else:
return MemValue(x)
def unmemorize(x):
if isinstance(x, (tuple, list)):
return tuple(unmemorize(i) for i in x)
else:
return x.read()
class FunctionBlock(Function):
def on_first_call(self, wrapped_function):
old_block = get_tape().active_basicblock
parent_node = get_tape().req_node
get_tape().open_scope(lambda x: x[0], None, 'begin-' + self.name)
block = get_tape().active_basicblock
block.alloc_pool = defaultdict(list)
del parent_node.children[-1]
self.node = get_tape().req_node
if get_program().verbose:
print('Compiling function', self.name)
result = wrapped_function(*self.compile_args)
if result is not None:
self.result = memorize(result)
else:
self.result = None
if get_program().verbose:
print('Done compiling function', self.name)
p_return_address = get_tape().program.malloc(1, 'ci')
get_tape().function_basicblocks[block] = p_return_address
return_address = regint.load_mem(p_return_address)
get_tape().active_basicblock.set_exit(instructions.jmpi(return_address, add_to_prog=False))
self.last_sub_block = get_tape().active_basicblock
get_tape().close_scope(old_block, parent_node, 'end-' + self.name)
old_block.set_exit(instructions.jmp(0, add_to_prog=False), get_tape().active_basicblock)
self.basic_block = block
def on_call(self, base, bases):
if base is not None:
instructions.starg(regint(base))
block = self.basic_block
if block not in get_tape().function_basicblocks:
raise CompilerError('unknown function')
old_block = get_tape().active_basicblock
old_block.set_exit(instructions.jmp(0, add_to_prog=False), block)
p_return_address = get_tape().function_basicblocks[block]
return_address = get_tape().new_reg('ci')
old_block.return_address_store = instructions.ldint(return_address, 0)
instructions.stmint(return_address, p_return_address)
get_tape().start_new_basicblock(name='call-' + self.name)
get_tape().active_basicblock.set_return(old_block, self.last_sub_block)
get_tape().req_node.children.append(self.node)
if self.result is not None:
return unmemorize(self.result)
def function_block(function):
return FunctionBlock(function)
def function_block_with_compile_args(*args):
def wrapper(function):
return FunctionBlock(function, compile_args=args)
return wrapper
def method_block(function):
# If you use this, make sure to use MemValue for all member
# variables.
compiled_functions = {}
def wrapper(self, *args):
if self in compiled_functions:
return compiled_functions[self](*args)
else:
name = '%s-%s' % (type(self).__name__, function.__name__)
block = FunctionBlock(function, name=name, compile_args=(self,))
compiled_functions[self] = block
return block(*args)
return wrapper
def cond_swap(x,y):
b = x < y
if isinstance(x, sfloat):
res = ([], [])
for i,j in enumerate(('v','p','z','s')):
xx = x.__getattribute__(j)
yy = y.__getattribute__(j)
bx = b * xx
by = b * yy
res[0].append(bx + yy - by)
res[1].append(xx - bx + by)
return sfloat(*res[0]), sfloat(*res[1])
bx = b * x
by = b * y
return bx + y - by, x - bx + by
def sort(a):
res = a
for i in range(len(a)):
for j in reversed(list(range(i))):
res[j], res[j+1] = cond_swap(res[j], res[j+1])
return res
def odd_even_merge(a):
if len(a) == 2:
a[0], a[1] = cond_swap(a[0], a[1])
else:
even = a[::2]
odd = a[1::2]
odd_even_merge(even)
odd_even_merge(odd)
a[0] = even[0]
for i in range(1, len(a) // 2):
a[2*i-1], a[2*i] = cond_swap(odd[i-1], even[i])
a[-1] = odd[-1]
def odd_even_merge_sort(a):
if len(a) == 1:
return
elif len(a) % 2 == 0:
lower = a[:len(a)//2]
upper = a[len(a)//2:]
odd_even_merge_sort(lower)
odd_even_merge_sort(upper)
a[:] = lower + upper
odd_even_merge(a)
else:
raise CompilerError('Length of list must be power of two')
def chunky_odd_even_merge_sort(a):
tmp = a[0].Array(len(a))
for i,j in enumerate(a):
tmp[i] = j
l = 1
while l < len(a):
l *= 2
k = 1
while k < l:
k *= 2
def round():
for i in range(len(a)):
a[i] = tmp[i]
for i in range(len(a) // l):
for j in range(l // k):
base = i * l + j
step = l // k
if k == 2:
a[base], a[base+step] = cond_swap(a[base], a[base+step])
else:
b = a[base:base+k*step:step]
for m in range(base + step, base + (k - 1) * step, 2 * step):
a[m], a[m+step] = cond_swap(a[m], a[m+step])
for i in range(len(a)):
tmp[i] = a[i]
chunk = MPCThread(round, 'sort-%d-%d' % (l,k), single_thread=True)
chunk.start()
chunk.join()
#round()
for i in range(len(a)):
a[i] = tmp[i]
def chunkier_odd_even_merge_sort(a, n=None, max_chunk_size=512, n_threads=7, use_chunk_wraps=False):
if n is None:
n = len(a)
a_base = instructions.program.malloc(n, 's')
for i,j in enumerate(a):
store_in_mem(j, a_base + i)
else:
a_base = a
tmp_base = instructions.program.malloc(n, 's')
chunks = {}
threads = []
def run_threads():
for thread in threads:
thread.start()
for thread in threads:
thread.join()
del threads[:]
def run_chunk(size, base):
if size not in chunks:
def swap_list(list_base):
for i in range(size // 2):
base = list_base + 2 * i
x, y = cond_swap(sint.load_mem(base),
sint.load_mem(base + 1))
store_in_mem(x, base)
store_in_mem(y, base + 1)
chunks[size] = FunctionTape(swap_list, 'sort-%d' % size)
return chunks[size](base)
def run_round(size):
# minimize number of chunk sizes
n_chunks = int(math.ceil(1.0 * size / max_chunk_size))
lower_size = size // n_chunks // 2 * 2
n_lower_size = n_chunks - (size - n_chunks * lower_size) // 2
# print len(to_swap) == lower_size * n_lower_size + \
# (lower_size + 2) * (n_chunks - n_lower_size), \
# len(to_swap), n_chunks, lower_size, n_lower_size
base = 0
round_threads = []
for i in range(n_lower_size):
round_threads.append(run_chunk(lower_size, tmp_base + base))
base += lower_size
for i in range(n_chunks - n_lower_size):
round_threads.append(run_chunk(lower_size + 2, tmp_base + base))
base += lower_size + 2
run_threads_in_rounds(round_threads)
postproc_chunks = []
wrap_chunks = {}
post_threads = []
pre_threads = []
def load_and_store(x, y, to_right):
if to_right:
store_in_mem(sint.load_mem(x), y)
else:
store_in_mem(sint.load_mem(y), x)
def run_setup(k, a_addr, step, tmp_addr):
if k == 2:
def mem_op(preproc, a_addr, step, tmp_addr):
load_and_store(a_addr, tmp_addr, preproc)
load_and_store(a_addr + step, tmp_addr + 1, preproc)
res = 2
else:
def mem_op(preproc, a_addr, step, tmp_addr):
instructions.program.curr_tape.merge_opens = False
# for i,m in enumerate(range(a_addr + step, a_addr + (k - 1) * step, step)):
for i in range(k - 2):
m = a_addr + step + i * step
load_and_store(m, tmp_addr + i, preproc)
res = k - 2
if not use_chunk_wraps or k <= 4:
mem_op(True, a_addr, step, tmp_addr)
postproc_chunks.append((mem_op, (a_addr, step, tmp_addr)))
else:
if k not in wrap_chunks:
pre_chunk = FunctionTape(mem_op, 'pre-%d' % k,
compile_args=[True])
post_chunk = FunctionTape(mem_op, 'post-%d' % k,
compile_args=[False])
wrap_chunks[k] = (pre_chunk, post_chunk)
pre_chunk, post_chunk = wrap_chunks[k]
pre_threads.append(pre_chunk(a_addr, step, tmp_addr))
post_threads.append(post_chunk(a_addr, step, tmp_addr))
return res
def run_threads_in_rounds(all_threads):
for thread in all_threads:
if len(threads) == n_threads:
run_threads()
threads.append(thread)
run_threads()
del all_threads[:]
def run_postproc():
run_threads_in_rounds(post_threads)
for chunk,args in postproc_chunks:
chunk(False, *args)
postproc_chunks[:] = []
l = 1
while l < n:
l *= 2
k = 1
while k < l:
k *= 2
size = 0
instructions.program.curr_tape.merge_opens = False
for i in range(n // l):
for j in range(l // k):
base = i * l + j
step = l // k
size += run_setup(k, a_base + base, step, tmp_base + size)
run_threads_in_rounds(pre_threads)
run_round(size)
run_postproc()
if isinstance(a, list):
for i in range(n):
a[i] = sint.load_mem(a_base + i)
instructions.program.free(a_base, 's')
instructions.program.free(tmp_base, 's')
def loopy_chunkier_odd_even_merge_sort(a, n=None, max_chunk_size=512, n_threads=7):
if n is None:
n = len(a)
a_base = instructions.program.malloc(n, 's')
for i,j in enumerate(a):
store_in_mem(j, a_base + i)
else:
a_base = a
tmp_base = instructions.program.malloc(n, 's')
tmp_i = instructions.program.malloc(1, 'ci')
chunks = {}
threads = []
def run_threads():
for thread in threads:
thread.start()
for thread in threads:
thread.join()
del threads[:]
def run_threads_in_rounds(all_threads):
for thread in all_threads:
if len(threads) == n_threads:
run_threads()
threads.append(thread)
run_threads()
del all_threads[:]
def run_chunk(size, base):
if size not in chunks:
def swap_list(list_base):
for i in range(size // 2):
base = list_base + 2 * i
x, y = cond_swap(sint.load_mem(base),
sint.load_mem(base + 1))
store_in_mem(x, base)
store_in_mem(y, base + 1)
chunks[size] = FunctionTape(swap_list, 'sort-%d' % size)
return chunks[size](base)
def run_round(size):
# minimize number of chunk sizes
n_chunks = int(math.ceil(1.0 * size / max_chunk_size))
lower_size = size // n_chunks // 2 * 2
n_lower_size = n_chunks - (size - n_chunks * lower_size) // 2
# print len(to_swap) == lower_size * n_lower_size + \
# (lower_size + 2) * (n_chunks - n_lower_size), \
# len(to_swap), n_chunks, lower_size, n_lower_size
base = 0
round_threads = []
for i in range(n_lower_size):
round_threads.append(run_chunk(lower_size, tmp_base + base))
base += lower_size
for i in range(n_chunks - n_lower_size):
round_threads.append(run_chunk(lower_size + 2, tmp_base + base))
base += lower_size + 2
run_threads_in_rounds(round_threads)
l = 1
while l < n:
l *= 2
k = 1
while k < l:
k *= 2
def load_and_store(x, y):
if to_tmp:
store_in_mem(sint.load_mem(x), y)
else:
store_in_mem(sint.load_mem(y), x)
def outer(i):
def inner(j):
base = j + a_base + i * l
step = l // k
if k == 2:
tmp_addr = regint.load_mem(tmp_i)
load_and_store(base, tmp_addr)
load_and_store(base + step, tmp_addr + 1)
store_in_mem(tmp_addr + 2, tmp_i)
else:
def inner2(m):
m += base
tmp_addr = regint.load_mem(tmp_i)
load_and_store(m, tmp_addr)
store_in_mem(tmp_addr + 1, tmp_i)
range_loop(inner2, step, (k - 1) * step, step)
range_loop(inner, l // k)
instructions.program.curr_tape.merge_opens = False
to_tmp = True
store_in_mem(tmp_base, tmp_i)
range_loop(outer, n // l)
if k == 2:
run_round(n)
else:
run_round(n // k * (k - 2))
instructions.program.curr_tape.merge_opens = False
to_tmp = False
store_in_mem(tmp_base, tmp_i)
range_loop(outer, n // l)
if isinstance(a, list):
for i in range(n):
a[i] = sint.load_mem(a_base + i)
instructions.program.free(a_base, 's')
instructions.program.free(tmp_base, 's')
instructions.program.free(tmp_i, 'ci')
def loopy_odd_even_merge_sort(a, sorted_length=1, n_parallel=32,
n_threads=None):
steps = {}
l = sorted_length
while l < len(a):
l *= 2
k = 1
while k < l:
k *= 2
n_innermost = 1 if k == 2 else k // 2 - 1
key = k
if key not in steps:
@function_block
def step(l):
l = MemValue(l)
m = 2 ** int(math.ceil(math.log(len(a), 2)))
@for_range_opt_multithread(n_threads, m // k)
def _(i):
n_inner = l // k
j = i % n_inner
i //= n_inner
base = i*l + j
step = l//k
def swap(base, step):
if m == len(a):
a[base], a[base + step] = \
cond_swap(a[base], a[base + step])
else:
# ignore values outside range
go = base + step < len(a)
x = a.maybe_get(go, base)
y = a.maybe_get(go, base + step)
tmp = cond_swap(x, y)
for i, idx in enumerate((base, base + step)):
a.maybe_set(go, idx, tmp[i])
if k == 2:
swap(base, step)
else:
@for_range_opt(n_innermost)
def f(i):
m1 = step + i * 2 * step
m2 = m1 + base
swap(m2, step)
steps[key] = step
steps[key](l)
def mergesort(A):
B = Array(len(A), sint)
def merge(i_left, i_right, i_end):
i0 = MemValue(i_left)
i1 = MemValue(i_right)
@for_range(i_left, i_end)
def loop(j):
if_then(and_(lambda: i0 < i_right,
or_(lambda: i1 >= i_end,
lambda: regint(reveal(A[i0] <= A[i1])))))
B[j] = A[i0]
i0.iadd(1)
else_then()
B[j] = A[i1]
i1.iadd(1)
end_if()
width = MemValue(1)
@do_while
def width_loop():
@for_range(0, len(A), 2 * width)
def merge_loop(i):
merge(i, i + width, i + 2 * width)
A.assign(B)
width.imul(2)
return width < len(A)
def range_loop(loop_body, start, stop=None, step=None):
if stop is None:
stop = start
start = 0
if step is None:
step = 1
def loop_fn(i):
res = loop_body(i)
return util.if_else(res == 0, stop, i + step)
if isinstance(step, int):
if step > 0:
condition = lambda x: x < stop
elif step < 0:
condition = lambda x: x > stop
else:
raise CompilerError('step must not be zero')
else:
b = step > 0
condition = lambda x: b * (x < stop) + (1 - b) * (x > stop)
while_loop(loop_fn, condition, start, g=loop_body.__globals__)
if isinstance(start, int) and isinstance(stop, int) \
and isinstance(step, int):
# known loop count
if condition(start):
get_tape().req_node.children[-1].aggregator = \
lambda x: ((stop - start) // step) * x[0]
def for_range(start, stop=None, step=None):
"""
Decorator to execute loop bodies consecutively. Arguments work as
in Python :py:func:`range`, but they can by any public
integer. Information has to be passed out via container types such
as :py:class:`~Compiler.types.Array` or declaring registers as
:py:obj:`global`. Note that changing Python data structures such
as lists within the loop is not possible, but the compiler cannot
warn about this.
:param start/stop/step: regint/cint/int
Example:
.. code::
a = sint.Array(n)
x = sint(0)
@for_range(n)
def _(i):
a[i] = i
global x
x += 1
"""
def decorator(loop_body):
range_loop(loop_body, start, stop, step)
return loop_body
return decorator
def for_range_parallel(n_parallel, n_loops):
"""
Decorator to execute a loop :py:obj:`n_loops` up to
:py:obj:`n_parallel` loop bodies in parallel.
Using any other control flow instruction inside the loop breaks
the optimization.
:param n_parallel: compile-time (int)
:param n_loops: regint/cint/int
Example:
.. code::
@for_range_parallel(n_parallel, n_loops)
def _(i):
a[i] = a[i] * a[i]
"""
return map_reduce_single(n_parallel, n_loops)
def for_range_opt(n_loops, budget=None):
""" Execute loop bodies in parallel up to an optimization budget.
This prevents excessive loop unrolling. The budget is respected
even with nested loops. Note that the optimization is rather
rudimentary for runtime :py:obj:`n_loops` (regint/cint). Consider
using :py:func:`for_range_parallel` in this case.
Using further control flow constructions inside other than
:py:func:`for_range_opt` (e.g, :py:func:`for_range`) breaks the
optimization.
:param n_loops: int/regint/cint
:param budget: number of instructions after which to start optimization (default is 100,000)
Example:
.. code::
@for_range_opt(n)
def _(i):
...
"""
return map_reduce_single(None, n_loops, budget=budget)
def map_reduce_single(n_parallel, n_loops, initializer=lambda *x: [],
reducer=lambda *x: [], mem_state=None, budget=None):
budget = budget or get_program().budget
if not (isinstance(n_parallel, int) or n_parallel is None):
raise CompilerError('Number of parallel executions must be constant')
n_parallel = 1 if is_zero(n_parallel) else n_parallel
if mem_state is None:
# default to list of MemValues to allow varying types
mem_state = [MemValue(x) for x in initializer()]
use_array = False
else:
# use Arrays for multithread version
use_array = True
if not util.is_constant(n_loops):
budget //= 10
def decorator(loop_body):
my_n_parallel = n_parallel
if isinstance(n_parallel, int):
if isinstance(n_loops, int):
loop_rounds = n_loops // n_parallel \
if n_parallel < n_loops else 0
else:
loop_rounds = n_loops / n_parallel
def write_state_to_memory(r):
if use_array:
mem_state.assign(r)
else:
# cannot do mem_state = [...] due to scope issue
for j,x in enumerate(r):
mem_state[j].write(x)
if n_parallel is not None:
# will be optimized out if n_loops <= n_parallel
@for_range(loop_rounds)
def f(i):
state = tuplify(initializer())
for k in range(n_parallel):
j = i * n_parallel + k
state = reducer(tuplify(loop_body(j)), state)
r = reducer(mem_state, state)
write_state_to_memory(r)
else:
if is_zero(n_loops):
return
n_opt_loops_reg = regint(0)
n_opt_loops_inst = get_block().instructions[-1]
parent_block = get_block()
@while_do(lambda x: x + n_opt_loops_reg <= n_loops, regint(0))
def _(i):
state = tuplify(initializer())
k = 0
block = get_block()
while (not util.is_constant(n_loops) or k < n_loops) \
and (len(get_block()) < budget or k == 0) \
and block is get_block():
j = i + k
state = reducer(tuplify(loop_body(j)), state)
k += 1
r = reducer(mem_state, state)
write_state_to_memory(r)
global n_opt_loops
n_opt_loops = k
n_opt_loops_inst.args[1] = k
return i + k
my_n_parallel = n_opt_loops
loop_rounds = n_loops // my_n_parallel
blocks = get_tape().basicblocks
n_to_merge = 5
if util.is_one(loop_rounds) and parent_block is blocks[-n_to_merge]:
# merge blocks started by if and do_while
def exit_elimination(block):
if block.exit_condition is not None:
for reg in block.exit_condition.get_used():
reg.can_eliminate = True
exit_elimination(parent_block)
merged = parent_block
merged.exit_condition = blocks[-1].exit_condition
merged.exit_block = blocks[-1].exit_block
assert parent_block is blocks[-n_to_merge]
assert blocks[-n_to_merge + 1] is \
get_tape().req_node.children[-1].nodes[0].blocks[0]
for block in blocks[-n_to_merge + 1:]:
merged.instructions += block.instructions
exit_elimination(block)
block.purge(retain_usage=False)
del blocks[-n_to_merge + 1:]
del get_tape().req_node.children[-1]
merged.children = []
RegintOptimizer().run(merged.instructions)
get_tape().active_basicblock = merged
else:
req_node = get_tape().req_node.children[-1].nodes[0]
if util.is_constant(loop_rounds):
req_node.children[0].aggregator = lambda x: loop_rounds * x[0]
if isinstance(n_loops, int):
state = mem_state
for j in range(loop_rounds * my_n_parallel, n_loops):
state = reducer(tuplify(loop_body(j)), state)
else:
@for_range(loop_rounds * my_n_parallel, n_loops)
def f(j):
r = reducer(tuplify(loop_body(j)), mem_state)
write_state_to_memory(r)
state = mem_state
for i,x in enumerate(state):
if use_array:
mem_state[i] = x
else:
mem_state[i].write(x)
def returner():
return untuplify(tuple(state))
return returner
return decorator
def for_range_multithread(n_threads, n_parallel, n_loops, thread_mem_req={}):
"""
Execute :py:obj:`n_loops` loop bodies in up to :py:obj:`n_threads`
threads, up to :py:obj:`n_parallel` in parallel per thread.
:param n_threads/n_parallel: compile-time (int)
:param n_loops: regint/cint/int
"""
return map_reduce(n_threads, n_parallel, n_loops, \
lambda *x: [], lambda *x: [], thread_mem_req)
def for_range_opt_multithread(n_threads, n_loops):
"""
Execute :py:obj:`n_loops` loop bodies in up to :py:obj:`n_threads`
threads, in parallel up to an optimization budget per thread
similar to :py:func:`for_range_opt`. Note that optimization is rather
rudimentary for runtime :py:obj:`n_loops` (regint/cint). Consider
using :py:func:`for_range_multithread` in this case.
:param n_threads: compile-time (int)
:param n_loops: regint/cint/int
The following will execute loop bodies 0-9 in one thread, 10-19 in
another etc:
.. code::
@for_range_opt_multithread(8, 80)
def _(i):
...
Multidimensional ranges are supported as well. The following
executes ``f(0, 0)`` to ``f(2, 0)`` in one thread and ``f(2, 1)``
to ``f(4, 2)`` in another.
.. code::
@for_range_opt_multithread(2, [5, 3])
def f(i, j):
...
"""
return for_range_multithread(n_threads, None, n_loops)
def multithread(n_threads, n_items=None, max_size=None):
"""
Distribute the computation of :py:obj:`n_items` to
:py:obj:`n_threads` threads, but leave the in-thread repetition up
to the user.
:param n_threads: compile-time (int)
:param n_items: regint/cint/int (default: :py:obj:`n_threads`)
The following executes ``f(0, 8)``, ``f(8, 8)``, and
``f(16, 9)`` in three different threads:
.. code::
@multithread(8, 25)
def f(base, size):
...
"""
if n_items is None:
n_items = n_threads
if max_size is None or n_items <= max_size:
return map_reduce(n_threads, None, n_items, initializer=lambda: [],
reducer=None, looping=False)
else:
def wrapper(function):
@multithread(n_threads, n_items)
def new_function(base, size):
@for_range(size // max_size)
def _(i):
function(base + i * max_size, max_size)
rem = size % max_size
if rem:
function(base + size - rem, rem)
return wrapper
def map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, \
thread_mem_req={}, looping=True):
assert(n_threads != 0)
if isinstance(n_loops, (list, tuple)):
split = n_loops
n_loops = reduce(operator.mul, n_loops)
def decorator(loop_body):
def new_body(i):
indices = []
for n in reversed(split):
indices.insert(0, i % n)
i //= n
return loop_body(*indices)
return new_body
new_dec = map_reduce(n_threads, n_parallel, n_loops, initializer, reducer, thread_mem_req)
return lambda loop_body: new_dec(decorator(loop_body))
n_loops = MemValue.if_necessary(n_loops)
if n_threads == None or util.is_one(n_loops):
if not looping:
return lambda loop_body: loop_body(0, n_loops)
dec = map_reduce_single(n_parallel, n_loops, initializer, reducer)
if thread_mem_req:
thread_mem = Array(thread_mem_req[regint], regint)
return lambda loop_body: dec(lambda i: loop_body(i, thread_mem))
else:
return dec
def decorator(loop_body):
thread_rounds = MemValue.if_necessary(n_loops // n_threads)
if util.is_constant(thread_rounds):
remainder = n_loops % n_threads
else:
remainder = 0
for t in thread_mem_req:
if t != regint:
raise CompilerError('Not implemented for other than regint')
args = Matrix(n_threads, 2 + thread_mem_req.get(regint, 0), 'ci')
state = tuple(initializer())
def f(inc):
base = args[get_arg()][0]
if not util.is_constant(thread_rounds):
i = base / thread_rounds
overhang = n_loops % n_threads
inc = i < overhang
base += inc.if_else(i, overhang)
if not looping:
return loop_body(base, thread_rounds + inc)
if thread_mem_req:
thread_mem = Array(thread_mem_req[regint], regint, \
args[get_arg()].address + 2)
mem_state = Array(len(state), type(state[0]) \
if state else cint, args[get_arg()][1])
@map_reduce_single(n_parallel, thread_rounds + inc, \
initializer, reducer, mem_state)
def f(i):
if thread_mem_req:
return loop_body(base + i, thread_mem)
else:
return loop_body(base + i)
prog = get_program()
thread_args = []
if not util.is_zero(thread_rounds):
tape = prog.new_tape(f, (0,), 'multithread')
for i in range(n_threads - remainder):
mem_state = make_array(initializer())
args[remainder + i][0] = i * thread_rounds
if len(mem_state):
args[remainder + i][1] = mem_state.address
thread_args.append((tape, remainder + i))
if remainder:
tape1 = prog.new_tape(f, (1,), 'multithread1')
for i in range(remainder):
mem_state = make_array(initializer())
args[i][0] = (n_threads - remainder + i) * thread_rounds + i
if len(mem_state):
args[i][1] = mem_state.address
thread_args.append((tape1, i))
threads = prog.run_tapes(thread_args)
for thread in threads:
prog.join_tape(thread)
if state:
if thread_rounds:
for i in range(n_threads - remainder):
state = reducer(Array(len(state), type(state[0]), \
args[remainder + i][1]), state)
if remainder:
for i in range(remainder):
state = reducer(Array(len(state), type(state[0]).reg_type, \
args[i][1]), state)
def returner():
return untuplify(state)
return returner
return decorator
def map_sum(n_threads, n_parallel, n_loops, n_items, value_types):
value_types = tuplify(value_types)
if len(value_types) == 1:
value_types *= n_items
elif len(value_types) != n_items:
raise CompilerError('Incorrect number of value_types.')
initializer = lambda: [t(0) for t in value_types]
def summer(x,y):
return tuple(a + b for a,b in zip(x,y))
return map_reduce(n_threads, n_parallel, n_loops, initializer, summer)
def tree_reduce_multithread(n_threads, function, vector):
inputs = vector.Array(len(vector))
inputs.assign_vector(vector)
outputs = vector.Array(len(vector) // 2)
left = len(vector)
while left > 1:
@multithread(n_threads, left // 2)
def _(base, size):
outputs.assign_vector(
function(inputs.get_vector(2 * base, size),
inputs.get_vector(2 * base + size, size)), base)
inputs.assign_vector(outputs.get_vector(0, left // 2))
if left % 2 == 1:
inputs[left // 2] = inputs[left - 1]
left = (left + 1) // 2
return inputs[0]
def foreach_enumerate(a):
""" Run-time loop over public data. This uses
``Player-Data/Public-Input/<progname>``. Example:
.. code::
@foreach_enumerate([2, 8, 3])
def _(i, j):
print_ln('%s: %s', i, j)
This will output:
.. code::
0: 2
1: 8
2: 3
"""
for x in a:
get_program().public_input(' '.join(str(y) for y in tuplify(x)))
def decorator(loop_body):
@for_range(len(a))
def f(i):
loop_body(i, *(public_input() for j in range(len(tuplify(a[0])))))
return f
return decorator
def while_loop(loop_body, condition, arg, g=None):
if not callable(condition):
raise CompilerError('Condition must be callable')
# store arg in stack
pre_condition = condition(arg)
if not isinstance(pre_condition, (bool,int)) or pre_condition:
arg = regint(arg)
def loop_fn():
result = loop_body(arg)
result.link(arg)
cont = condition(result)
return cont
if_statement(pre_condition, lambda: do_while(loop_fn, g=g))
def while_do(condition, *args):
""" While-do loop. The decorator requires an initialization, and
the loop body function must return a suitable input for
:py:obj:`condition`.
:param condition: function returning public integer (regint/cint/int)
:param args: arguments given to :py:obj:`condition` and loop body
The following executes an ten-fold loop:
.. code::
@while_do(lambda x: x < 10, regint(0))
def f(i):
...
return i + 1
"""
def decorator(loop_body):
while_loop(loop_body, condition, *args)
return loop_body
return decorator
def do_loop(condition, loop_fn):
# store initial condition to stack
pushint(condition if isinstance(condition,regint) else regint(condition))
def wrapped_loop():
# save condition to stack
new_cond = regint.pop()
# run the loop
condition = loop_fn(new_cond)
pushint(condition)
return condition
do_while(wrapped_loop)
regint.pop()
def _run_and_link(function, g=None):
if g is None:
g = function.__globals__
import copy
pre = copy.copy(g)
res = function()
if g:
from .types import _single
for name, var in pre.items():
if isinstance(var, (program.Tape.Register, _single)):
new_var = g[name]
if id(new_var) != id(var):
new_var.link(var)
return res
def do_while(loop_fn, g=None):
""" Do-while loop. The loop is stopped if the return value is zero.
It must be public. The following executes exactly once:
.. code::
@do_while
def _():
...
return regint(0)
"""
scope = instructions.program.curr_block
parent_node = get_tape().req_node
# possibly unknown loop count
get_tape().open_scope(lambda x: x[0].set_all(float('Inf')), \
name='begin-loop')
loop_block = instructions.program.curr_block
condition = _run_and_link(loop_fn, g)
if callable(condition):
condition = condition()
branch = instructions.jmpnz(regint.conv(condition), 0, add_to_prog=False)
instructions.program.curr_block.set_exit(branch, loop_block)
get_tape().close_scope(scope, parent_node, 'end-loop')
return loop_fn
def if_then(condition):
class State: pass
state = State()
if callable(condition):
condition = condition()
state.condition = regint.conv(condition)
state.start_block = instructions.program.curr_block
state.req_child = get_tape().open_scope(lambda x: x[0].max(x[1]), \
name='if-block')
state.has_else = False
instructions.program.curr_tape.if_states.append(state)
def else_then():
try:
state = instructions.program.curr_tape.if_states[-1]
except IndexError:
raise CompilerError('No open if block')
if state.has_else:
raise CompilerError('else block already defined')
# run the else block
state.if_exit_block = instructions.program.curr_block
state.req_child.add_node(get_tape(), 'else-block')
instructions.program.curr_tape.start_new_basicblock(state.start_block, \
name='else-block')
state.else_block = instructions.program.curr_block
state.has_else = True
def end_if():
try:
state = instructions.program.curr_tape.if_states.pop()
except IndexError:
raise CompilerError('No open if/else block')
branch = instructions.jmpeqz(regint.conv(state.condition), 0, \
add_to_prog=False)
# start next block
get_tape().close_scope(state.start_block, state.req_child.parent, 'end-if')
if state.has_else:
# jump to else block if condition == 0
state.start_block.set_exit(branch, state.else_block)
# set if block to skip else
jump = instructions.jmp(0, add_to_prog=False)
state.if_exit_block.set_exit(jump, instructions.program.curr_block)
else:
# set start block's conditional jump to next block
state.start_block.set_exit(branch, instructions.program.curr_block)
# nothing to compute without else
state.req_child.aggregator = lambda x: x[0]
def if_statement(condition, if_fn, else_fn=None):
if condition is True or condition is False:
# condition known at compile time
if condition:
if_fn()
elif else_fn is not None:
else_fn()
else:
state = if_then(condition)
if_fn()
if else_fn is not None:
else_then()
else_fn()
end_if()
def if_(condition):
"""
Conditional execution without else block.
:param condition: regint/cint/int
Usage:
.. code::
@if_(x > 0)
def _():
...
"""
def decorator(body):
if_then(condition)
_run_and_link(body)
end_if()
return decorator
def if_e(condition):
"""
Conditional execution with else block.
:param condition: regint/cint/int
Usage:
.. code::
@if_e(x > 0)
def _():
...
@else_
def _():
...
"""
def decorator(body):
if_then(condition)
_run_and_link(body)
return decorator
def else_(body):
else_then()
_run_and_link(body)
end_if()
def and_(*terms):
res = regint(0)
for term in terms:
if_then(term())
old_res = res
res = regint(1)
res.link(old_res)
for term in terms:
else_then()
end_if()
def load_result():
return res
return load_result
def or_(*terms):
res = regint(1)
for term in terms:
if_then(term())
else_then()
old_res = res
res = regint(0)
res.link(old_res)
for term in terms:
end_if()
def load_result():
return res
return load_result
def not_(term):
return lambda: 1 - term()
def start_timer(timer_id=0):
""" Start timer. Timer 0 runs from the start of the program. The
total time of all used timers is output at the end. Fails if
already running.
:param timer_id: compile-time (int) """
get_tape().start_new_basicblock(name='pre-start-timer')
start(timer_id)
get_tape().start_new_basicblock(name='post-start-timer')
def stop_timer(timer_id=0):
""" Stop timer. Fails if not running.
:param timer_id: compile-time (int) """
get_tape().start_new_basicblock(name='pre-stop-timer')
stop(timer_id)
get_tape().start_new_basicblock(name='post-stop-timer')
def get_number_of_players():
"""
:return: the number of players
:rtype: regint
"""
res = regint()
nplayers(res)
return res
def get_threshold():
""" The threshold is the maximal number of corrupted
players.
:rtype: regint
"""
res = regint()
threshold(res)
return res
def get_player_id():
"""
:return: player number
:rtype: localint (cannot be used for computation) """
res = localint()
playerid(res._v)
return res
def break_point(name=''):
"""
Insert break point. This makes sure that all following code
will be executed after preceding code.
:param name: Name for identification (optional)
"""
get_tape().start_new_basicblock(name=name)
def check_point():
"""
Force MAC checks in current thread and all idle threads if the
current thread is the main thread. This implies a break point.
"""
break_point('pre-check')
check()
break_point('post-check')
# Fixed point ops
from math import ceil, log
from .floatingpoint import PreOR, TruncPr, two_power
def approximate_reciprocal(divisor, k, f, theta):
"""
returns aproximation of 1/divisor
where type(divisor) = cint
"""
def twos_complement(x):
bits = x.bit_decompose(k)[::-1]
twos_result = cint(0)
for i in range(k):
val = twos_result
val <<= 1
val += 1 - bits[i]
twos_result = val
return twos_result + 1
bits = divisor.bit_decompose(k)[::-1]
flag = regint(0)
cnt_leading_zeros = regint(0)
normalized_divisor = divisor
for i in range(k):
flag = flag | (bits[i] == 1)
flag_zero = cint(flag == 0)
cnt_leading_zeros += flag_zero
normalized_divisor <<= flag_zero
q = two_power(k)
e = twos_complement(normalized_divisor)
for i in range(theta):
q += (q * e) >> k
e = (e * e) >> k
res = q >> cint(2*k - 2*f - cnt_leading_zeros)
return res
def cint_cint_division(a, b, k, f):
"""
Goldschmidt method implemented with
SE aproximation:
http://stackoverflow.com/questions/2661541/picking-good-first-estimates-for-goldschmidt-division
"""
# theta can be replaced with something smaller
# for safety we assume that is the same theta from previous GS method
if get_program().options.ring:
assert 2 * f < int(get_program().options.ring)
theta = int(ceil(log(k/3.5) / log(2)))
two = cint(2) * two_power(f)
sign_b = cint(1) - 2 * cint(b.less_than(0, k))
sign_a = cint(1) - 2 * cint(a.less_than(0, k))
absolute_b = b * sign_b
absolute_a = a * sign_a
w0 = approximate_reciprocal(absolute_b, k, f, theta)
A = absolute_a
B = absolute_b
W = w0
corr = cint(1) << (f - 1)
for i in range(theta):
A = (A * W + corr) >> f
B = (B * W + corr) >> f
W = two - B
return (sign_a * sign_b) * A
from Compiler.program import Program
def sint_cint_division(a, b, k, f, kappa):
"""
type(a) = sint, type(b) = cint
"""
theta = int(ceil(log(k/3.5) / log(2)))
two = cint(2) * two_power(f)
sign_b = cint(1) - 2 * cint(b.less_than(0, k))
sign_a = sint(1) - 2 * comparison.LessThanZero(a, k, kappa)
absolute_b = b * sign_b
absolute_a = a * sign_a
w0 = approximate_reciprocal(absolute_b, k, f, theta)
A = absolute_a
B = absolute_b
W = w0
@for_range(1, theta)
def block(i):
A.link(TruncPr(A * W, 2*k, f, kappa))
temp = (B * W) >> f
W.link(two - temp)
B.link(temp)
return (sign_a * sign_b) * A
def IntDiv(a, b, k, kappa=None):
return FPDiv(a.extend(2 * k) << k, b.extend(2 * k) << k, 2 * k, k,
kappa, nearest=True)
@instructions_base.ret_cisc
def FPDiv(a, b, k, f, kappa, simplex_flag=False, nearest=False):
"""
Goldschmidt method as presented in Catrina10,
"""
prime = get_program().prime
if 2 * k == int(get_program().options.ring) or \
(prime and 2 * k <= (prime.bit_length() - 1)):
# not fitting otherwise
nearest = True
if get_program().options.binary:
# no probabilistic truncation in binary circuits
nearest = True
res_f = f
f = max((k - nearest) // 2 + 1, f)
assert 2 * f > k - nearest
theta = int(ceil(log(k/3.5) / log(2)))
base.set_global_vector_size(b.size)
alpha = b.get_type(2 * k).two_power(2*f)
w = AppRcr(b, k, f, kappa, simplex_flag, nearest).extend(2 * k)
x = alpha - b.extend(2 * k) * w
base.reset_global_vector_size()
y = a.extend(2 *k) * w
y = y.round(2*k, f, kappa, nearest, signed=True)
for i in range(theta - 1):
x = x.extend(2 * k)
y = y.extend(2 * k) * (alpha + x).extend(2 * k)
x = x * x
y = y.round(2*k, 2*f, kappa, nearest, signed=True)
x = x.round(2*k, 2*f, kappa, nearest, signed=True)
x = x.extend(2 * k)
y = y.extend(2 * k) * (alpha + x).extend(2 * k)
y = y.round(k + 3 * f - res_f, 3 * f - res_f, kappa, nearest, signed=True)
return y
def AppRcr(b, k, f, kappa=None, simplex_flag=False, nearest=False):
"""
Approximate reciprocal of [b]:
Given [b], compute [1/b]
"""
alpha = b.get_type(2 * k)(int(2.9142 * 2**k))
c, v = b.Norm(k, f, kappa, simplex_flag)
#v should be 2**{k - m} where m is the length of the bitwise repr of [b]
d = alpha - 2 * c
w = d * v
w = w.round(2 * k + 1, 2 * (k - f), kappa, nearest, signed=True)
# now w * 2 ^ {-f} should be an initial approximation of 1/b
return w
def Norm(b, k, f, kappa, simplex_flag=False):
"""
Computes secret integer values [c] and [v_prime] st.
2^{k-1} <= c < 2^k and c = b*v_prime
"""
# For simplex, we can get rid of computing abs(b)
temp = None
if simplex_flag == False:
temp = comparison.LessThanZero(b, k, kappa)
elif simplex_flag == True:
temp = cint(0)
sign = 1 - 2 * temp # 1 - 2 * [b < 0]
absolute_val = sign * b
#next 2 lines actually compute the SufOR for little indian encoding
bits = absolute_val.bit_decompose(k, kappa, maybe_mixed=True)[::-1]
suffixes = PreOR(bits, kappa)[::-1]
z = [0] * k
for i in range(k - 1):
z[i] = suffixes[i] - suffixes[i+1]
z[k - 1] = suffixes[k-1]
acc = sint.bit_compose(reversed(z))
part_reciprocal = absolute_val * acc
signed_acc = sign * acc
return part_reciprocal, signed_acc