mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-01-09 13:37:58 -05:00
Improved binary circuit functionality.
This commit is contained in:
@@ -661,6 +661,9 @@ class sbitvec(_vec):
|
||||
return sbit.malloc(size * n, creator_tape=creator_tape)
|
||||
@staticmethod
|
||||
def n_elements():
|
||||
return 1
|
||||
@staticmethod
|
||||
def mem_size():
|
||||
return n
|
||||
@classmethod
|
||||
def get_input_from(cls, player):
|
||||
@@ -692,22 +695,33 @@ class sbitvec(_vec):
|
||||
self.v = sbits.get_type(n)(other).bit_decompose()
|
||||
assert len(self.v) == n
|
||||
@classmethod
|
||||
def load_mem(cls, address):
|
||||
def load_mem(cls, address, size=None):
|
||||
if size not in (None, 1):
|
||||
assert isinstance(address, int) or len(address) == 1
|
||||
sb = sbits.get_type(size)
|
||||
return cls.from_vec(sb.bit_compose(
|
||||
sbit.load_mem(address + i + j * n) for j in range(size))
|
||||
for i in range(n))
|
||||
if not isinstance(address, int) and len(address) == n:
|
||||
return cls.from_vec(sbit.load_mem(x) for x in address)
|
||||
else:
|
||||
return cls.from_vec(sbit.load_mem(address + i)
|
||||
for i in range(n))
|
||||
def store_in_mem(self, address):
|
||||
size = 1
|
||||
for x in self.v:
|
||||
assert util.is_constant(x) or x.n == 1
|
||||
v = [sbit.conv(x) for x in self.v]
|
||||
if not util.is_constant(x):
|
||||
size = max(size, x.n)
|
||||
v = [sbits.get_type(size).conv(x) for x in self.v]
|
||||
if not isinstance(address, int) and len(address) == n:
|
||||
assert max_n == 1
|
||||
for x, y in zip(v, address):
|
||||
x.store_in_mem(y)
|
||||
else:
|
||||
assert isinstance(address, int) or len(address) == 1
|
||||
for i in range(n):
|
||||
v[i].store_in_mem(address + i)
|
||||
for j, x in enumerate(v[i].bit_decompose()):
|
||||
x.store_in_mem(address + i + j * n)
|
||||
def reveal(self):
|
||||
if len(self) > cbits.unit:
|
||||
return self.elements()[0].reveal()
|
||||
@@ -861,6 +875,19 @@ class sbitvec(_vec):
|
||||
return self ^ other
|
||||
def right_shift(self, m, k, security=None, signed=True):
|
||||
return self.from_vec(self.v[m:])
|
||||
def tree_reduce(self, function):
|
||||
elements = self.elements()
|
||||
while len(elements) > 1:
|
||||
size = len(elements)
|
||||
half = size // 2
|
||||
left = elements[:half]
|
||||
right = elements[half:2*half]
|
||||
odd = elements[2*half:]
|
||||
sides = [self.from_vec(sbitvec(x).v) for x in (left, right)]
|
||||
red = function(*sides)
|
||||
elements = red.elements()
|
||||
elements += odd
|
||||
return self.from_vec(sbitvec(elements).v)
|
||||
|
||||
class bit(object):
|
||||
n = 1
|
||||
|
||||
@@ -5701,7 +5701,8 @@ class SubMultiArray(_vectorizable):
|
||||
self.sub_cache[key] = \
|
||||
Array(self.sizes[1], self.value_type, \
|
||||
self.address + index * self.sizes[1] *
|
||||
self.value_type.n_elements(), \
|
||||
self.value_type.n_elements() * \
|
||||
self.value_type.mem_size(), \
|
||||
debug=self.debug)
|
||||
else:
|
||||
self.sub_cache[key] = \
|
||||
|
||||
@@ -116,6 +116,11 @@ def round_to_int(x):
|
||||
return x.round_to_int()
|
||||
|
||||
def tree_reduce(function, sequence):
|
||||
try:
|
||||
return sequence.tree_reduce(function)
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
sequence = list(sequence)
|
||||
assert len(sequence) > 0
|
||||
n = len(sequence)
|
||||
|
||||
@@ -730,6 +730,7 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const
|
||||
case ANDM:
|
||||
case NOTS:
|
||||
case NOTCB:
|
||||
case TRANS:
|
||||
size = DIV_CEIL(n, 64);
|
||||
break;
|
||||
case CONVCBIT2S:
|
||||
|
||||
Reference in New Issue
Block a user