Fix bugs in matrix multiplication with binary circuits.

This commit is contained in:
Marcel Keller
2022-07-25 18:12:04 +10:00
parent 101879f37a
commit 81419ba321
4 changed files with 33 additions and 5 deletions

View File

@@ -17,6 +17,7 @@ from Compiler import instructions_base
import Compiler.GC.instructions as inst
import operator
import math
import itertools
from functools import reduce
class bits(Tape.Register, _structure, _bit):
@@ -1182,9 +1183,20 @@ class sbitintvec(sbitvec, _number, _bitint, _sbitintbase):
return self.from_vec(other * x for x in self.v)
elif isinstance(other, sbitfixvec):
return NotImplemented
other_bits = util.bit_decompose(other)
m = float('inf')
for x in itertools.chain(self.v, other_bits):
try:
m = min(m, x.n)
except:
pass
if m == 1:
op = operator.mul
else:
op = operator.and_
matrix = []
for i, b in enumerate(util.bit_decompose(other)):
matrix.append([x & b for x in self.v[:len(self.v)-i]])
for i, b in enumerate(other_bits):
matrix.append([op(x, b) for x in self.v[:len(self.v)-i]])
v = sbitint.wallace_tree_from_matrix(matrix)
return self.from_vec(v[:len(self.v)])
__rmul__ = __mul__

View File

@@ -5263,7 +5263,8 @@ class Array(_vectorizable):
# length can be None for single-element arrays
length = 0
base = self.address + index * self.value_type.mem_size()
if size is not None and isinstance(base, _register):
if size is not None and isinstance(base, _register) \
and not issubclass(self.value_type, _vec):
base = regint._expand_address(base, size)
self.address_cache[program.curr_block, key] = \
util.untuplify([base + i * length \
@@ -6063,6 +6064,7 @@ class SubMultiArray(_vectorizable):
assert n_threads is None
if max(res_matrix.sizes) > 1000:
raise AttributeError()
self.value_type.matrix_mul
A = self.get_vector()
B = other.get_vector()
res_matrix.assign_vector(

View File

@@ -146,7 +146,7 @@ public:
if (this != &res)
res.get_regs().assign(this->get_regs().begin(),
this->get_regs().begin()
+ max(size_t(n_bits), this->get_regs().size()));
+ min(size_t(n_bits), this->get_regs().size()));
res.resize_regs(n_bits);
}

View File

@@ -666,6 +666,21 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const
return r[1] + size;
else
return 0;
case TRANS:
if (reg_type == SBIT)
{
int n_outputs = n;
auto& args = start;
int n_inputs = args.size() - n_outputs;
long long res = 0;
for (int i = 0; i < n_outputs; i++)
res = max(res, args[i] + DIV_CEIL(n_inputs, 64));
for (int j = 0; j < n_inputs; j++)
res = max(res, args[n_outputs] + DIV_CEIL(n_outputs, 64));
return res;
}
else
return 0;
default:
if (get_reg_type() != reg_type)
return 0;
@@ -731,7 +746,6 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const
case ANDM:
case NOTS:
case NOTCB:
case TRANS:
size = DIV_CEIL(n, 64);
break;
case CONVCBIT2S: