mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-01-08 21:18:03 -05:00
Make MATMULSM mergeable
This commit is contained in:
@@ -323,8 +323,8 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
|
||||
get_vector(num_var_args, start, s);
|
||||
break;
|
||||
case MATMULSM:
|
||||
get_ints(r, s, 3);
|
||||
get_vector(9, start, s);
|
||||
num_var_args = get_int(s);
|
||||
get_vector(num_var_args, start, s);
|
||||
break;
|
||||
|
||||
// read from file, input is opcode num_args,
|
||||
@@ -1117,8 +1117,7 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
|
||||
Proc.Procp.matmuls(Proc.Procp.get_S(), *this);
|
||||
return;
|
||||
case MATMULSM:
|
||||
Proc.Procp.protocol.matmulsm(Proc.Procp, Proc.machine.Mp.MS, *this,
|
||||
Proc.read_Ci(r[1]), Proc.read_Ci(r[2]));
|
||||
Proc.Procp.protocol.matmulsm(Proc.Procp, Proc.machine.Mp.MS, *this);
|
||||
return;
|
||||
case CONV2DS:
|
||||
Proc.Procp.protocol.conv2ds(Proc.Procp, *this);
|
||||
|
||||
@@ -77,8 +77,12 @@ public:
|
||||
void mulrs(const vector<int>& reg);
|
||||
void dotprods(const vector<int>& reg, int size);
|
||||
void matmuls(const vector<T>& source, const Instruction& instruction);
|
||||
void matmulsm(const MemoryPart<T>& source, const Instruction& instruction, size_t a,
|
||||
size_t b);
|
||||
void matmulsm(const MemoryPart<T>& source, const Instruction& instruction);
|
||||
|
||||
void matmulsm_finalize_batch(vector<int>::const_iterator startMatmul, int startI, int startJ,
|
||||
vector<int>::const_iterator endMatmul,
|
||||
int endI, int endJ);
|
||||
|
||||
void conv2ds(const Instruction& instruction);
|
||||
|
||||
void secure_shuffle(const Instruction& instruction);
|
||||
|
||||
@@ -601,73 +601,156 @@ void SubProcessor<T>::matmuls(const vector<T>& source,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template<class T>
|
||||
void SubProcessor<T>::matmulsm(const MemoryPart<T>& source,
|
||||
const Instruction& instruction, size_t a, size_t b)
|
||||
const Instruction& instruction)
|
||||
{
|
||||
auto& dim = instruction.get_start();
|
||||
auto C = S.begin() + (instruction.get_r(0));
|
||||
assert(C + dim[0] * dim[2] <= S.end());
|
||||
assert(Proc);
|
||||
|
||||
int base = 0;
|
||||
int base2 = 0;
|
||||
auto& start = instruction.get_start();
|
||||
|
||||
auto batchStartMatrix = start.begin();
|
||||
int batchStartI = 0;
|
||||
int batchStartJ = 0;
|
||||
|
||||
size_t sourceSize = source.size();
|
||||
const T* sourceData = source.data();
|
||||
|
||||
protocol.init_dotprod();
|
||||
for (int i = 0; i < dim[0]; i++)
|
||||
{
|
||||
auto ii = Proc->get_Ci().at(dim[3] + i).get();
|
||||
for (int j = 0; j < dim[2]; j++)
|
||||
{
|
||||
#ifdef DEBUG_MATMULSM
|
||||
cerr << "matmulsm prep " << i << " " << j << endl;
|
||||
for (auto matmulArgs = start.begin(); matmulArgs < start.end(); matmulArgs += 12) {
|
||||
auto output = S.begin() + matmulArgs[0];
|
||||
size_t firstFactorBase = Proc->get_Ci().at(matmulArgs[1]).get();
|
||||
size_t secondFactorBase = Proc->get_Ci().at(matmulArgs[2]).get();
|
||||
auto resultNumberOfRows = matmulArgs[3];
|
||||
auto usedNumberOfFirstFactorColumns = matmulArgs[4];
|
||||
auto resultNumberOfColumns = matmulArgs[5];
|
||||
auto firstFactorTotalNumberOfColumns = matmulArgs[10];
|
||||
auto secondFactorTotalNumberOfColumns = matmulArgs[11];
|
||||
|
||||
assert(output + resultNumberOfRows * resultNumberOfColumns <= S.end());
|
||||
|
||||
for (int i = 0; i < resultNumberOfRows; i += 1) {
|
||||
auto actualFirstFactorRow = Proc->get_Ci().at(matmulArgs[6] + i).get();
|
||||
|
||||
for (int j = 0; j < resultNumberOfColumns; j += 1) {
|
||||
auto actualSecondFactorColumn = Proc->get_Ci().at(matmulArgs[9] + j).get();
|
||||
|
||||
#ifdef MATMULSM_DEBUG
|
||||
cout << "Preparing " << i << "," << j << "(buffer size: " << protocol.get_buffer_size() << ")" << endl;
|
||||
#endif
|
||||
matmulsm_prep(ii, j, source, dim, a, b);
|
||||
if (protocol.get_buffer_size() > OnlineOptions::singleton.batch_size)
|
||||
{
|
||||
#ifdef DEBUG_MATMULSM
|
||||
cerr << "matmulsm round " << protocol.get_buffer_size() << endl;
|
||||
#endif
|
||||
protocol.exchange();
|
||||
if (base < i)
|
||||
for (int l = base2; l < dim[2]; l++)
|
||||
matmulsm_finalize(base, l, dim, C);
|
||||
for (int k = base + 1; k < i; k++)
|
||||
for (int l = 0; l < dim[2]; l++)
|
||||
matmulsm_finalize(k, l, dim, C);
|
||||
for (int l = base < i ? 0 : base2; l <= j; l++)
|
||||
matmulsm_finalize(i, l, dim, C);
|
||||
base = i;
|
||||
base2 = j + 1;
|
||||
protocol.init_dotprod();
|
||||
|
||||
for (int k = 0; k < usedNumberOfFirstFactorColumns; k += 1) {
|
||||
auto actualFirstFactorColumn = Proc->get_Ci().at(matmulArgs[7] + k).get();
|
||||
auto actualSecondFactorRow = Proc->get_Ci().at(matmulArgs[8] + k).get();
|
||||
|
||||
auto firstAddress = firstFactorBase + actualFirstFactorRow * firstFactorTotalNumberOfColumns + actualFirstFactorColumn;
|
||||
auto secondAddress = secondFactorBase + actualSecondFactorRow * secondFactorTotalNumberOfColumns + actualSecondFactorColumn;
|
||||
|
||||
assert(firstAddress < sourceSize);
|
||||
assert(secondAddress < sourceSize);
|
||||
|
||||
protocol.prepare_dotprod(sourceData[firstAddress], sourceData[secondAddress]);
|
||||
}
|
||||
protocol.next_dotprod();
|
||||
|
||||
if (protocol.get_buffer_size() > OnlineOptions::singleton.batch_size) {
|
||||
protocol.exchange();
|
||||
|
||||
matmulsm_finalize_batch(batchStartMatrix, batchStartI, batchStartJ,
|
||||
matmulArgs, i, j);
|
||||
batchStartMatrix = matmulArgs;
|
||||
batchStartI = i;
|
||||
batchStartJ = j + 1;
|
||||
|
||||
protocol.init_dotprod();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
protocol.exchange();
|
||||
for (int j = base2; j < dim[2]; j++)
|
||||
matmulsm_finalize(base, j, dim, C);
|
||||
for (int i = base + 1; i < dim[0]; i++)
|
||||
for (int j = 0; j < dim[2]; j++)
|
||||
matmulsm_finalize(i, j, dim, C);
|
||||
auto lastMatmulsArgs = start.end() - 12;
|
||||
auto lastMatrixRows = lastMatmulsArgs[3];
|
||||
auto lastMatrixColumns = lastMatmulsArgs[5];
|
||||
matmulsm_finalize_batch(batchStartMatrix, batchStartI, batchStartJ,
|
||||
lastMatmulsArgs, lastMatrixRows - 1, lastMatrixColumns - 1);
|
||||
}
|
||||
|
||||
template<class T>
|
||||
void SubProcessor<T>::matmulsm_prep(int ii, int j, const MemoryPart<T>& source,
|
||||
const vector<int>& dim, size_t a, size_t b)
|
||||
{
|
||||
auto jj = Proc->get_Ci().at(dim[6] + j).get();
|
||||
const T* base = source.data();
|
||||
size_t size = source.size();
|
||||
for (int k = 0; k < dim[1]; k++)
|
||||
{
|
||||
auto kk = Proc->get_Ci().at(dim[4] + k).get();
|
||||
auto ll = Proc->get_Ci().at(dim[5] + k).get();
|
||||
auto aa = a + ii * dim[7] + kk;
|
||||
auto bb = b + ll * dim[8] + jj;
|
||||
assert(aa < size);
|
||||
assert(bb < size);
|
||||
protocol.prepare_dotprod(base[aa], base[bb]);
|
||||
void SubProcessor<T>::matmulsm_finalize_batch(vector<int>::const_iterator startMatmul, int startI, int startJ,
|
||||
vector<int>::const_iterator endMatmul, int endI, int endJ) {
|
||||
|
||||
for (auto matmulArgs = startMatmul; matmulArgs <= endMatmul; matmulArgs += 12) {
|
||||
auto output = S.begin() + matmulArgs[0];
|
||||
auto resultNumberOfRows = matmulArgs[3];
|
||||
auto usedNumberOfFirstFactorColumns = matmulArgs[4];
|
||||
auto resultNumberOfColumns = matmulArgs[5];
|
||||
|
||||
assert(output + resultNumberOfRows * resultNumberOfColumns <= S.end());
|
||||
|
||||
// Finish the first unfinished row in the current matrix.
|
||||
int firstRowEndJ = resultNumberOfColumns - 1;
|
||||
if (matmulArgs == endMatmul && startI == endI) // For the case that the batch covers only a part of the first row of current matrix or only part of a single row.
|
||||
firstRowEndJ = endJ;
|
||||
#ifdef MATMULSM_DEBUG
|
||||
cout << "Batch is in single row " << endJ << endl;
|
||||
#endif
|
||||
for (int j = startJ; j <= firstRowEndJ; j += 1) {
|
||||
#ifdef MATMULSM_DEBUG
|
||||
cout << "Finalizing (first row) " << startI << "," << j << endl;
|
||||
#endif
|
||||
*(output + startI * resultNumberOfColumns + j) = protocol.finalize_dotprod(usedNumberOfFirstFactorColumns);
|
||||
}
|
||||
if (firstRowEndJ == resultNumberOfColumns - 1) {
|
||||
startJ = 0;
|
||||
startI += 1;
|
||||
}
|
||||
else {
|
||||
// The whole batch covers only a part of a single row.
|
||||
startJ = endJ + 1;
|
||||
}
|
||||
|
||||
// Determine the point up until which the batch runs in the current matrix.
|
||||
int currentMatrixEndI = resultNumberOfRows - 1;
|
||||
int currentMatrixEndJ = resultNumberOfColumns - 1;
|
||||
if (matmulArgs == endMatmul) {
|
||||
currentMatrixEndI = endI;
|
||||
currentMatrixEndJ = endJ;
|
||||
}
|
||||
|
||||
// Finish the rows that always are complete, i.e., the second to the "second to last" row.
|
||||
for (; startI <= currentMatrixEndI - 1; startI += 1) {
|
||||
for (int j = 0; j < resultNumberOfColumns; j += 1) {
|
||||
#ifdef MATMULSM_DEBUG
|
||||
cout << "Finalizing (main part) " << startI << "," << j << endl;
|
||||
#endif
|
||||
*(output + startI * resultNumberOfColumns + j) = protocol.finalize_dotprod(usedNumberOfFirstFactorColumns);
|
||||
}
|
||||
}
|
||||
|
||||
// (Partially) finish the last row.
|
||||
if (startI == currentMatrixEndI) {
|
||||
for (; startJ <= currentMatrixEndJ; startJ += 1) {
|
||||
#ifdef MATMULSM_DEBUG
|
||||
cout << "Finalizing (last row) " << startI << "," << startJ << endl;
|
||||
#endif
|
||||
*(output + startI * resultNumberOfColumns + startJ) = protocol.finalize_dotprod(usedNumberOfFirstFactorColumns);
|
||||
}
|
||||
}
|
||||
else {
|
||||
#ifdef MATMULSM_DEBUG
|
||||
// This happens when there is only one row.
|
||||
cout << "Skipping final row of matrix because it was handled previously." << endl;
|
||||
#endif
|
||||
}
|
||||
|
||||
if (matmulArgs < endMatmul) {
|
||||
// Reset startI and startJ to the beginning of the matrix.
|
||||
startI = 0;
|
||||
startJ = 0;
|
||||
}
|
||||
}
|
||||
protocol.next_dotprod();
|
||||
}
|
||||
|
||||
template<class T>
|
||||
|
||||
Reference in New Issue
Block a user