mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-01-08 21:18:03 -05:00
Fix bugs in matrix multiplication with binary circuits.
This commit is contained in:
@@ -17,6 +17,7 @@ from Compiler import instructions_base
|
||||
import Compiler.GC.instructions as inst
|
||||
import operator
|
||||
import math
|
||||
import itertools
|
||||
from functools import reduce
|
||||
|
||||
class bits(Tape.Register, _structure, _bit):
|
||||
@@ -1182,9 +1183,20 @@ class sbitintvec(sbitvec, _number, _bitint, _sbitintbase):
|
||||
return self.from_vec(other * x for x in self.v)
|
||||
elif isinstance(other, sbitfixvec):
|
||||
return NotImplemented
|
||||
other_bits = util.bit_decompose(other)
|
||||
m = float('inf')
|
||||
for x in itertools.chain(self.v, other_bits):
|
||||
try:
|
||||
m = min(m, x.n)
|
||||
except:
|
||||
pass
|
||||
if m == 1:
|
||||
op = operator.mul
|
||||
else:
|
||||
op = operator.and_
|
||||
matrix = []
|
||||
for i, b in enumerate(util.bit_decompose(other)):
|
||||
matrix.append([x & b for x in self.v[:len(self.v)-i]])
|
||||
for i, b in enumerate(other_bits):
|
||||
matrix.append([op(x, b) for x in self.v[:len(self.v)-i]])
|
||||
v = sbitint.wallace_tree_from_matrix(matrix)
|
||||
return self.from_vec(v[:len(self.v)])
|
||||
__rmul__ = __mul__
|
||||
|
||||
@@ -5263,7 +5263,8 @@ class Array(_vectorizable):
|
||||
# length can be None for single-element arrays
|
||||
length = 0
|
||||
base = self.address + index * self.value_type.mem_size()
|
||||
if size is not None and isinstance(base, _register):
|
||||
if size is not None and isinstance(base, _register) \
|
||||
and not issubclass(self.value_type, _vec):
|
||||
base = regint._expand_address(base, size)
|
||||
self.address_cache[program.curr_block, key] = \
|
||||
util.untuplify([base + i * length \
|
||||
@@ -6063,6 +6064,7 @@ class SubMultiArray(_vectorizable):
|
||||
assert n_threads is None
|
||||
if max(res_matrix.sizes) > 1000:
|
||||
raise AttributeError()
|
||||
self.value_type.matrix_mul
|
||||
A = self.get_vector()
|
||||
B = other.get_vector()
|
||||
res_matrix.assign_vector(
|
||||
|
||||
@@ -146,7 +146,7 @@ public:
|
||||
if (this != &res)
|
||||
res.get_regs().assign(this->get_regs().begin(),
|
||||
this->get_regs().begin()
|
||||
+ max(size_t(n_bits), this->get_regs().size()));
|
||||
+ min(size_t(n_bits), this->get_regs().size()));
|
||||
|
||||
res.resize_regs(n_bits);
|
||||
}
|
||||
|
||||
@@ -666,6 +666,21 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const
|
||||
return r[1] + size;
|
||||
else
|
||||
return 0;
|
||||
case TRANS:
|
||||
if (reg_type == SBIT)
|
||||
{
|
||||
int n_outputs = n;
|
||||
auto& args = start;
|
||||
int n_inputs = args.size() - n_outputs;
|
||||
long long res = 0;
|
||||
for (int i = 0; i < n_outputs; i++)
|
||||
res = max(res, args[i] + DIV_CEIL(n_inputs, 64));
|
||||
for (int j = 0; j < n_inputs; j++)
|
||||
res = max(res, args[n_outputs] + DIV_CEIL(n_outputs, 64));
|
||||
return res;
|
||||
}
|
||||
else
|
||||
return 0;
|
||||
default:
|
||||
if (get_reg_type() != reg_type)
|
||||
return 0;
|
||||
@@ -731,7 +746,6 @@ unsigned BaseInstruction::get_max_reg(int reg_type) const
|
||||
case ANDM:
|
||||
case NOTS:
|
||||
case NOTCB:
|
||||
case TRANS:
|
||||
size = DIV_CEIL(n, 64);
|
||||
break;
|
||||
case CONVCBIT2S:
|
||||
|
||||
Reference in New Issue
Block a user