Improved binary circuit functionality.

This commit is contained in:
Marcel Keller
2022-06-14 16:14:37 +02:00
parent 6755a8fa51
commit 4c8e616b58
4 changed files with 39 additions and 5 deletions

View File

@@ -661,6 +661,9 @@ class sbitvec(_vec):
return sbit.malloc(size * n, creator_tape=creator_tape)
@staticmethod
def n_elements():
return 1
@staticmethod
def mem_size():
return n
@classmethod
def get_input_from(cls, player):
@@ -692,22 +695,33 @@ class sbitvec(_vec):
self.v = sbits.get_type(n)(other).bit_decompose()
assert len(self.v) == n
@classmethod
def load_mem(cls, address):
def load_mem(cls, address, size=None):
if size not in (None, 1):
assert isinstance(address, int) or len(address) == 1
sb = sbits.get_type(size)
return cls.from_vec(sb.bit_compose(
sbit.load_mem(address + i + j * n) for j in range(size))
for i in range(n))
if not isinstance(address, int) and len(address) == n:
return cls.from_vec(sbit.load_mem(x) for x in address)
else:
return cls.from_vec(sbit.load_mem(address + i)
for i in range(n))
def store_in_mem(self, address):
size = 1
for x in self.v:
assert util.is_constant(x) or x.n == 1
v = [sbit.conv(x) for x in self.v]
if not util.is_constant(x):
size = max(size, x.n)
v = [sbits.get_type(size).conv(x) for x in self.v]
if not isinstance(address, int) and len(address) == n:
assert max_n == 1
for x, y in zip(v, address):
x.store_in_mem(y)
else:
assert isinstance(address, int) or len(address) == 1
for i in range(n):
v[i].store_in_mem(address + i)
for j, x in enumerate(v[i].bit_decompose()):
x.store_in_mem(address + i + j * n)
def reveal(self):
if len(self) > cbits.unit:
return self.elements()[0].reveal()
@@ -861,6 +875,19 @@ class sbitvec(_vec):
return self ^ other
def right_shift(self, m, k, security=None, signed=True):
return self.from_vec(self.v[m:])
def tree_reduce(self, function):
elements = self.elements()
while len(elements) > 1:
size = len(elements)
half = size // 2
left = elements[:half]
right = elements[half:2*half]
odd = elements[2*half:]
sides = [self.from_vec(sbitvec(x).v) for x in (left, right)]
red = function(*sides)
elements = red.elements()
elements += odd
return self.from_vec(sbitvec(elements).v)
class bit(object):
n = 1

View File

@@ -5701,7 +5701,8 @@ class SubMultiArray(_vectorizable):
self.sub_cache[key] = \
Array(self.sizes[1], self.value_type, \
self.address + index * self.sizes[1] *
self.value_type.n_elements(), \
self.value_type.n_elements() * \
self.value_type.mem_size(), \
debug=self.debug)
else:
self.sub_cache[key] = \

View File

@@ -116,6 +116,11 @@ def round_to_int(x):
return x.round_to_int()
def tree_reduce(function, sequence):
try:
return sequence.tree_reduce(function)
except AttributeError:
pass
sequence = list(sequence)
assert len(sequence) > 0
n = len(sequence)

View File

@@ -730,6 +730,7 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const
case ANDM:
case NOTS:
case NOTCB:
case TRANS:
size = DIV_CEIL(n, 64);
break;
case CONVCBIT2S: