Avoid excessive memory use in matrix multiplication with Beaver-based protocols.

This commit is contained in:
Marcel Keller
2023-02-10 12:23:53 +11:00
parent 21268970ed
commit 59afe5db53
4 changed files with 46 additions and 14 deletions

View File

@@ -36,6 +36,11 @@ class SubProcessor
void resize(size_t size) { C.resize(size); S.resize(size); }
void matmulsm_prep(int i, const CheckVector<T>& source,
const vector<int>& dim, size_t a, size_t b);
void matmulsm_finalize(int i, const vector<int>& dim,
typename vector<T>::iterator C);
template<class sint, class sgf2n> friend class Processor;
template<class U> friend class SPDZ;
template<class U> friend class ProtocolBase;

View File

@@ -540,27 +540,50 @@ void SubProcessor<T>::matmulsm(const CheckVector<T>& source,
assert(C + dim[0] * dim[2] <= S.end());
assert(Proc);
int base = 0;
protocol.init_dotprod();
for (int i = 0; i < dim[0]; i++)
{
auto ii = Proc->get_Ci().at(dim[3] + i);
for (int j = 0; j < dim[2]; j++)
matmulsm_prep(i, source, dim, a, b);
if (protocol.get_buffer_size() > OnlineOptions::singleton.batch_size)
{
auto jj = Proc->get_Ci().at(dim[6] + j);
for (int k = 0; k < dim[1]; k++)
{
auto kk = Proc->get_Ci().at(dim[4] + k);
auto ll = Proc->get_Ci().at(dim[5] + k);
protocol.prepare_dotprod(source.at(a + ii * dim[7] + kk),
source.at(b + ll * dim[8] + jj));
}
protocol.next_dotprod();
protocol.exchange();
for (int j = base; j <= i; j++)
matmulsm_finalize(j, dim, C);
base = i + 1;
protocol.init_dotprod();
}
}
protocol.exchange();
for (int i = 0; i < dim[0]; i++)
for (int j = 0; j < dim[2]; j++)
*(C + i * dim[2] + j) = protocol.finalize_dotprod(dim[1]);
for (int i = base; i < dim[0]; i++)
matmulsm_finalize(i, dim, C);
}
template<class T>
void SubProcessor<T>::matmulsm_prep(int i, const CheckVector<T>& source,
const vector<int>& dim, size_t a, size_t b)
{
auto ii = Proc->get_Ci().at(dim[3] + i);
for (int j = 0; j < dim[2]; j++)
{
auto jj = Proc->get_Ci().at(dim[6] + j);
for (int k = 0; k < dim[1]; k++)
{
auto kk = Proc->get_Ci().at(dim[4] + k);
auto ll = Proc->get_Ci().at(dim[5] + k);
protocol.prepare_dotprod(source.at(a + ii * dim[7] + kk),
source.at(b + ll * dim[8] + jj));
}
protocol.next_dotprod();
}
}
template<class T>
void SubProcessor<T>::matmulsm_finalize(int i, const vector<int>& dim,
typename vector<T>::iterator C)
{
for (int j = 0; j < dim[2]; j++)
*(C + i * dim[2] + j) = protocol.finalize_dotprod(dim[1]);
}
template<class T>