diff --git a/Processor/Processor.h b/Processor/Processor.h index 3e5d405b..414aeb93 100644 --- a/Processor/Processor.h +++ b/Processor/Processor.h @@ -36,9 +36,9 @@ class SubProcessor void resize(size_t size) { C.resize(size); S.resize(size); } - void matmulsm_prep(int i, const CheckVector& source, + void matmulsm_prep(int ii, int j, const CheckVector& source, const vector& dim, size_t a, size_t b); - void matmulsm_finalize(int i, const vector& dim, + void matmulsm_finalize(int i, int j, const vector& dim, typename vector::iterator C); template friend class Processor; diff --git a/Processor/Processor.hpp b/Processor/Processor.hpp index c2c5a9f2..1d5a0d3b 100644 --- a/Processor/Processor.hpp +++ b/Processor/Processor.hpp @@ -583,49 +583,68 @@ void SubProcessor::matmulsm(const CheckVector& source, assert(Proc); int base = 0; + int base2 = 0; protocol.init_dotprod(); for (int i = 0; i < dim[0]; i++) { - matmulsm_prep(i, source, dim, a, b); - if (protocol.get_buffer_size() > OnlineOptions::singleton.batch_size) + auto ii = Proc->get_Ci().at(dim[3] + i); + for (int j = 0; j < dim[2]; j++) { - protocol.exchange(); - for (int j = base; j <= i; j++) - matmulsm_finalize(j, dim, C); - base = i + 1; - protocol.init_dotprod(); +#ifdef DEBUG_MATMULSM + cerr << "matmulsm prep " << i << " " << j << endl; +#endif + matmulsm_prep(ii, j, source, dim, a, b); + if (protocol.get_buffer_size() > OnlineOptions::singleton.batch_size) + { +#ifdef DEBUG_MATMULSM + cerr << "matmulsm round " << protocol.get_buffer_size() << endl; +#endif + protocol.exchange(); + if (base < i) + for (int l = base2; l < dim[2]; l++) + matmulsm_finalize(base, l, dim, C); + for (int k = base + 1; k < i; k++) + for (int l = 0; l < dim[2]; l++) + matmulsm_finalize(k, l, dim, C); + for (int l = base < i ? 0 : base2; l <= j; l++) + matmulsm_finalize(i, l, dim, C); + base = i; + base2 = j + 1; + protocol.init_dotprod(); + } } } protocol.exchange(); - for (int i = base; i < dim[0]; i++) - matmulsm_finalize(i, dim, C); + for (int j = base2; j < dim[2]; j++) + matmulsm_finalize(base, j, dim, C); + for (int i = base + 1; i < dim[0]; i++) + for (int j = 0; j < dim[2]; j++) + matmulsm_finalize(i, j, dim, C); } template -void SubProcessor::matmulsm_prep(int i, const CheckVector& source, +void SubProcessor::matmulsm_prep(int ii, int j, const CheckVector& source, const vector& dim, size_t a, size_t b) { - auto ii = Proc->get_Ci().at(dim[3] + i); - for (int j = 0; j < dim[2]; j++) + auto jj = Proc->get_Ci().at(dim[6] + j); + for (int k = 0; k < dim[1]; k++) { - auto jj = Proc->get_Ci().at(dim[6] + j); - for (int k = 0; k < dim[1]; k++) - { - auto kk = Proc->get_Ci().at(dim[4] + k); - auto ll = Proc->get_Ci().at(dim[5] + k); - protocol.prepare_dotprod(source.at(a + ii * dim[7] + kk), - source.at(b + ll * dim[8] + jj)); - } - protocol.next_dotprod(); + auto kk = Proc->get_Ci().at(dim[4] + k); + auto ll = Proc->get_Ci().at(dim[5] + k); + protocol.prepare_dotprod(source.at(a + ii * dim[7] + kk), + source.at(b + ll * dim[8] + jj)); } + protocol.next_dotprod(); } template -void SubProcessor::matmulsm_finalize(int i, const vector& dim, +void SubProcessor::matmulsm_finalize(int i, int j, const vector& dim, typename vector::iterator C) { - for (int j = 0; j < dim[2]; j++) - *(C + i * dim[2] + j) = protocol.finalize_dotprod(dim[1]); +#ifdef DEBUG_MATMULSM + cerr << "matmulsm finalize " << i << " " << j << endl; +#endif + *(C + i * dim[2] + j) = protocol.finalize_dotprod(dim[1]); } template