mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-01-10 05:57:57 -05:00
Reduce memory usage in matrix multiplication with Beaver-based protocols.
This commit is contained in:
@@ -36,9 +36,9 @@ class SubProcessor
|
||||
|
||||
void resize(size_t size) { C.resize(size); S.resize(size); }
|
||||
|
||||
void matmulsm_prep(int i, const CheckVector<T>& source,
|
||||
void matmulsm_prep(int ii, int j, const CheckVector<T>& source,
|
||||
const vector<int>& dim, size_t a, size_t b);
|
||||
void matmulsm_finalize(int i, const vector<int>& dim,
|
||||
void matmulsm_finalize(int i, int j, const vector<int>& dim,
|
||||
typename vector<T>::iterator C);
|
||||
|
||||
template<class sint, class sgf2n> friend class Processor;
|
||||
|
||||
@@ -583,49 +583,68 @@ void SubProcessor<T>::matmulsm(const CheckVector<T>& source,
|
||||
assert(Proc);
|
||||
|
||||
int base = 0;
|
||||
int base2 = 0;
|
||||
protocol.init_dotprod();
|
||||
for (int i = 0; i < dim[0]; i++)
|
||||
{
|
||||
matmulsm_prep(i, source, dim, a, b);
|
||||
if (protocol.get_buffer_size() > OnlineOptions::singleton.batch_size)
|
||||
auto ii = Proc->get_Ci().at(dim[3] + i);
|
||||
for (int j = 0; j < dim[2]; j++)
|
||||
{
|
||||
protocol.exchange();
|
||||
for (int j = base; j <= i; j++)
|
||||
matmulsm_finalize(j, dim, C);
|
||||
base = i + 1;
|
||||
protocol.init_dotprod();
|
||||
#ifdef DEBUG_MATMULSM
|
||||
cerr << "matmulsm prep " << i << " " << j << endl;
|
||||
#endif
|
||||
matmulsm_prep(ii, j, source, dim, a, b);
|
||||
if (protocol.get_buffer_size() > OnlineOptions::singleton.batch_size)
|
||||
{
|
||||
#ifdef DEBUG_MATMULSM
|
||||
cerr << "matmulsm round " << protocol.get_buffer_size() << endl;
|
||||
#endif
|
||||
protocol.exchange();
|
||||
if (base < i)
|
||||
for (int l = base2; l < dim[2]; l++)
|
||||
matmulsm_finalize(base, l, dim, C);
|
||||
for (int k = base + 1; k < i; k++)
|
||||
for (int l = 0; l < dim[2]; l++)
|
||||
matmulsm_finalize(k, l, dim, C);
|
||||
for (int l = base < i ? 0 : base2; l <= j; l++)
|
||||
matmulsm_finalize(i, l, dim, C);
|
||||
base = i;
|
||||
base2 = j + 1;
|
||||
protocol.init_dotprod();
|
||||
}
|
||||
}
|
||||
}
|
||||
protocol.exchange();
|
||||
for (int i = base; i < dim[0]; i++)
|
||||
matmulsm_finalize(i, dim, C);
|
||||
for (int j = base2; j < dim[2]; j++)
|
||||
matmulsm_finalize(base, j, dim, C);
|
||||
for (int i = base + 1; i < dim[0]; i++)
|
||||
for (int j = 0; j < dim[2]; j++)
|
||||
matmulsm_finalize(i, j, dim, C);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void SubProcessor<T>::matmulsm_prep(int i, const CheckVector<T>& source,
|
||||
void SubProcessor<T>::matmulsm_prep(int ii, int j, const CheckVector<T>& source,
|
||||
const vector<int>& dim, size_t a, size_t b)
|
||||
{
|
||||
auto ii = Proc->get_Ci().at(dim[3] + i);
|
||||
for (int j = 0; j < dim[2]; j++)
|
||||
auto jj = Proc->get_Ci().at(dim[6] + j);
|
||||
for (int k = 0; k < dim[1]; k++)
|
||||
{
|
||||
auto jj = Proc->get_Ci().at(dim[6] + j);
|
||||
for (int k = 0; k < dim[1]; k++)
|
||||
{
|
||||
auto kk = Proc->get_Ci().at(dim[4] + k);
|
||||
auto ll = Proc->get_Ci().at(dim[5] + k);
|
||||
protocol.prepare_dotprod(source.at(a + ii * dim[7] + kk),
|
||||
source.at(b + ll * dim[8] + jj));
|
||||
}
|
||||
protocol.next_dotprod();
|
||||
auto kk = Proc->get_Ci().at(dim[4] + k);
|
||||
auto ll = Proc->get_Ci().at(dim[5] + k);
|
||||
protocol.prepare_dotprod(source.at(a + ii * dim[7] + kk),
|
||||
source.at(b + ll * dim[8] + jj));
|
||||
}
|
||||
protocol.next_dotprod();
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void SubProcessor<T>::matmulsm_finalize(int i, const vector<int>& dim,
|
||||
void SubProcessor<T>::matmulsm_finalize(int i, int j, const vector<int>& dim,
|
||||
typename vector<T>::iterator C)
|
||||
{
|
||||
for (int j = 0; j < dim[2]; j++)
|
||||
*(C + i * dim[2] + j) = protocol.finalize_dotprod(dim[1]);
|
||||
#ifdef DEBUG_MATMULSM
|
||||
cerr << "matmulsm finalize " << i << " " << j << endl;
|
||||
#endif
|
||||
*(C + i * dim[2] + j) = protocol.finalize_dotprod(dim[1]);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
|
||||
Reference in New Issue
Block a user