Vectorized fixed-point multiplication in binary circuits.

This commit is contained in:
Marcel Keller
2020-06-24 16:31:28 +10:00
parent 20f84de3b9
commit 898b87a78a
2 changed files with 41 additions and 7 deletions

View File

@@ -670,6 +670,9 @@ class DynamicArray(Array):
sbits.dynamic_array = DynamicArray
cbits.dynamic_array = Array
def _complement_two_extend(bits, k):
return bits + [bits[-1]] * (k - len(bits))
class sbitint(_bitint, _number, sbits):
n_bits = None
bin_type = None
@@ -798,6 +801,17 @@ class sbitintvec(sbitvec, _number):
return self.from_vec(v[:len(self.v)])
__rmul__ = __mul__
reduce_after_mul = lambda x: x
def TruncMul(self, other, k, m, kappa=None, nearest=False):
if nearest:
raise CompilerError('round to nearest not implemented')
if not isinstance(other, sbitintvec):
other = sbitintvec(other)
assert len(self.v) + len(other.v) == k
a = self.from_vec(_complement_two_extend(self.v, k))
b = self.from_vec(_complement_two_extend(other.v, k))
tmp = a * b
assert len(tmp.v) == k
return self.from_vec(tmp[m:])
sbitint.vec = sbitintvec
@@ -834,6 +848,8 @@ class sbitfix(_fix):
def __mul__(self, other):
if isinstance(other, sbit):
return type(self)(self.int_type(other * self.v))
elif isinstance(other, sbitfixvec):
return other * self
else:
return super(sbitfix, self).__mul__(other)
__rxor__ = __xor__
@@ -850,14 +866,29 @@ sbitfix.set_precision(20, 41)
class sbitfixvec(_fix):
int_type = sbitintvec
float_type = type(None)
@staticmethod
clear_type = type(None)
_f = None
_k = None
@property
def f():
return sbitfix.f
@staticmethod
def f(self):
if self._f is None:
return sbitfix.f
else:
return self._f
@f.setter
def f(self, value):
self._f = value
@property
def k():
return sbitfix.k
def k(self):
if self._k is None:
return sbitfix.k
else:
return self._k
@k.setter
def k(self, value):
self._k = value
def coerce(self, other):
return other
sbitfix.vec = sbitfixvec

View File

@@ -3254,7 +3254,10 @@ class _fix(_single):
v //= 2
k = len(bin(abs(v))) - 1
other = self.multipliable(v, k, f)
other = self.coerce(other)
try:
other = self.coerce(other)
except:
return NotImplemented
if isinstance(other, (_fix, self.clear_type)):
val = self.v.TruncMul(other.v, self.k + other.k, other.f,
self.kappa,