mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-01-10 05:57:57 -05:00
Vectorized fixed-point multiplication in binary circuits.
This commit is contained in:
@@ -670,6 +670,9 @@ class DynamicArray(Array):
|
||||
sbits.dynamic_array = DynamicArray
|
||||
cbits.dynamic_array = Array
|
||||
|
||||
def _complement_two_extend(bits, k):
|
||||
return bits + [bits[-1]] * (k - len(bits))
|
||||
|
||||
class sbitint(_bitint, _number, sbits):
|
||||
n_bits = None
|
||||
bin_type = None
|
||||
@@ -798,6 +801,17 @@ class sbitintvec(sbitvec, _number):
|
||||
return self.from_vec(v[:len(self.v)])
|
||||
__rmul__ = __mul__
|
||||
reduce_after_mul = lambda x: x
|
||||
def TruncMul(self, other, k, m, kappa=None, nearest=False):
|
||||
if nearest:
|
||||
raise CompilerError('round to nearest not implemented')
|
||||
if not isinstance(other, sbitintvec):
|
||||
other = sbitintvec(other)
|
||||
assert len(self.v) + len(other.v) == k
|
||||
a = self.from_vec(_complement_two_extend(self.v, k))
|
||||
b = self.from_vec(_complement_two_extend(other.v, k))
|
||||
tmp = a * b
|
||||
assert len(tmp.v) == k
|
||||
return self.from_vec(tmp[m:])
|
||||
|
||||
sbitint.vec = sbitintvec
|
||||
|
||||
@@ -834,6 +848,8 @@ class sbitfix(_fix):
|
||||
def __mul__(self, other):
|
||||
if isinstance(other, sbit):
|
||||
return type(self)(self.int_type(other * self.v))
|
||||
elif isinstance(other, sbitfixvec):
|
||||
return other * self
|
||||
else:
|
||||
return super(sbitfix, self).__mul__(other)
|
||||
__rxor__ = __xor__
|
||||
@@ -850,14 +866,29 @@ sbitfix.set_precision(20, 41)
|
||||
class sbitfixvec(_fix):
|
||||
int_type = sbitintvec
|
||||
float_type = type(None)
|
||||
@staticmethod
|
||||
clear_type = type(None)
|
||||
_f = None
|
||||
_k = None
|
||||
@property
|
||||
def f():
|
||||
return sbitfix.f
|
||||
@staticmethod
|
||||
def f(self):
|
||||
if self._f is None:
|
||||
return sbitfix.f
|
||||
else:
|
||||
return self._f
|
||||
@f.setter
|
||||
def f(self, value):
|
||||
self._f = value
|
||||
@property
|
||||
def k():
|
||||
return sbitfix.k
|
||||
def k(self):
|
||||
if self._k is None:
|
||||
return sbitfix.k
|
||||
else:
|
||||
return self._k
|
||||
@k.setter
|
||||
def k(self, value):
|
||||
self._k = value
|
||||
def coerce(self, other):
|
||||
return other
|
||||
|
||||
sbitfix.vec = sbitfixvec
|
||||
|
||||
|
||||
@@ -3254,7 +3254,10 @@ class _fix(_single):
|
||||
v //= 2
|
||||
k = len(bin(abs(v))) - 1
|
||||
other = self.multipliable(v, k, f)
|
||||
other = self.coerce(other)
|
||||
try:
|
||||
other = self.coerce(other)
|
||||
except:
|
||||
return NotImplemented
|
||||
if isinstance(other, (_fix, self.clear_type)):
|
||||
val = self.v.TruncMul(other.v, self.k + other.k, other.f,
|
||||
self.kappa,
|
||||
|
||||
Reference in New Issue
Block a user