Fix summation of binary vectors.

This commit is contained in:
Marcel Keller
2025-09-17 16:18:24 +08:00
parent 1091a2cf4c
commit e18f875f9c
3 changed files with 11 additions and 4 deletions

View File

@@ -792,6 +792,8 @@ class sbitvec(_vec, _bit, _binary):
@classmethod
def from_vec(cls, vector):
res = cls()
if isinstance(vector, sbitvec):
vector = vector.v
res.v = _complement_two_extend(list(vector), n)[:n]
return res
def __init__(self, other=None, size=None):
@@ -867,6 +869,8 @@ class sbitvec(_vec, _bit, _binary):
return sbitvecn
@classmethod
def from_vec(cls, vector):
if isinstance(vector, sbitvec):
vector = vector.v
res = cls()
res.v = list(vector)
return res
@@ -954,7 +958,7 @@ class sbitvec(_vec, _bit, _binary):
def if_else(self, x, y):
return util.if_else(self.v[0], x, y)
def __iter__(self):
return iter(self.v)
return iter(self.elements())
def __len__(self):
return len(self.v)
def __getitem__(self, index):
@@ -1423,7 +1427,7 @@ class sbitintvec(sbitvec, _bitint, _number, _sbitintbase):
def instruction(*args):
res = self.binary_mul(args[bl:2 * bl], args[2 * bl:],
args[0].n)
for x, y in zip(res, args):
for x, y in zip(sbitvec.from_vec(res).v, args):
x.mov(y, x)
instruction.__name__ = 'binary_mul%sx%s' % (bl, len(other_bits))
self.mul_functions[key] = instructions_base.cisc(instruction,

View File

@@ -69,7 +69,8 @@ class Circuit:
f = function_block
self.functions[n] = f(lambda *args: self.compile(*args))
self.functions[n].name = '%s(%d)' % (self.name, inputs[0][0].n)
flat_res = self.functions[n](*itertools.chain(*inputs))
flat_res = self.functions[n](*itertools.chain(*(
sbitvec.from_vec(x).v for x in inputs)))
res = []
i = 0
for l in self.n_output_wires:

View File

@@ -2297,9 +2297,11 @@ class _secret(_arithmetic_register, _secret_structure):
""" Compose value from bits.
:param bits: iterable of any type convertible to sint """
from Compiler.GC.types import sbits, sbitintvec
from Compiler.GC.types import sbits, sbitintvec, sbitvec
if isinstance(bits, sbits):
bits = bits.bit_decompose()
elif isinstance(bits, sbitvec):
bits = bits.v
bits = list(bits)
if (program.use_edabit() or program.use_split()) and isinstance(bits[0], sbits):
if program.use_edabit():