Files
MP-SPDZ/GC/square64.cpp
2021-07-27 21:40:29 +10:00

242 lines
6.1 KiB
C++

/*
* square64.cpp
*
*/
#include "square64.h"
#include "Tools/cpu_support.h"
#include "OT/BitMatrix.h"
#include <stdexcept>
#include <iostream>
#include <assert.h>
using namespace std;
union matrix32x8
{
__m256i whole;
octet rows[32];
matrix32x8(const __m256i& x = _mm256_setzero_si256()) : whole(x) {}
matrix32x8(square64& input, int x, int y)
{
for (int l = 0; l < 32; l++)
rows[l] = input.bytes[32*x+l][y];
}
void transpose(square64& output, int x, int y)
{
#if defined(__AVX2__) || !defined(__x86_64__)
if (cpu_has_avx2())
{
for (int j = 0; j < 8; j++)
{
int row = _mm256_movemask_epi8(whole);
whole = _mm256_slli_epi64(whole, 1);
// _mm_movemask_epi8 uses most significant bit, hence +7-j
output.halfrows[8*x+7-j][y] = row;
}
}
else
#endif
{
(void) output, (void) x, (void) y;
throw runtime_error("need AVX2 support");
}
}
};
#ifdef DEBUG_TRANS
// Debug helper: print the four 64-bit lanes of x in hex, lowest lane first.
ostream& operator<<(ostream& os, const __m256i& x)
{
    // long long is 64 bits on all mainstream data models, whereas long is
    // only 32 bits on LLP64 (Windows), where the old cast printed only half
    // of the register; reading through a const pointer also avoids casting
    // away the const of x
    const long long* lanes = reinterpret_cast<const long long*>(&x);
    for (int i = 0; i < 4; i++)
        os << hex << " " << lanes[i];
    os << dec;
    return os;
}
#endif
// One switch case of zip(): interleave SRC_A and SRC_B in WIDTH-bit units
// using the AVX2 unpack intrinsics of that width.
#define ZIP_CASE(WIDTH, LO_OUT, HI_OUT, SRC_A, SRC_B) \
    case WIDTH: \
        LO_OUT = _mm256_unpacklo_epi##WIDTH(SRC_A, SRC_B); \
        HI_OUT = _mm256_unpackhi_epi##WIDTH(SRC_A, SRC_B); \
        break;

/*
 * Interleave a and b in units of chunk_size bits.
 *
 * For 8, 16, 32 and 64 this is the AVX2 unpack pair, which works inside
 * each 128-bit lane: lows gets the interleaved lower halves of the lanes,
 * highs the interleaved upper halves. For 128, lows becomes
 * (a.lane0, b.lane0) and highs becomes (a.lane1, b.lane1).
 *
 * Throws runtime_error("need AVX2 support") without AVX2 and
 * invalid_argument("not supported") for any other chunk_size.
 */
void zip(int chunk_size, __m256i& lows, __m256i& highs,
        const __m256i& a, const __m256i& b)
{
#if defined(__AVX2__) || !defined(__x86_64__)
    if (not cpu_has_avx2())
        throw runtime_error("need AVX2 support");

    switch (chunk_size)
    {
    ZIP_CASE(8, lows, highs, a, b);
    ZIP_CASE(16, lows, highs, a, b);
    ZIP_CASE(32, lows, highs, a, b);
    ZIP_CASE(64, lows, highs, a, b);
    case 128:
    {
        // start from full copies, then exchange a's upper lane with b's
        // lower lane
        lows = a;
        highs = b;
        __m128i* low_lanes = (__m128i*) &lows;
        __m128i* high_lanes = (__m128i*) &highs;
        swap(low_lanes[1], high_lanes[0]);
        break;
    }
    default:
        throw invalid_argument("not supported");
    }
#else
    (void) chunk_size, (void) lows, (void) highs, (void) a, (void) b;
    throw runtime_error("need AVX2 support");
#endif
}
// Transpose the upper-left n_rows x n_cols corner of this 64x64 bit matrix
// in place: afterwards, bit c of old row r is bit r of row c. The matrix is
// cleared before the result is written, so entries that are not written
// end up zero.
//
// Two implementations, selected at compile time:
// - without __AVX2__: through a zero-padded square128 and its transpose();
// - with __AVX2__: 32 rows at a time; zip() rounds regroup the rows into
//   byte columns and matrix32x8 transposes each 32x8 block of bits.
// The AVX2 path handles whole groups of 32 rows and 8 columns, so input
// bits just outside the requested corner are transposed as well.
// NOTE(review): presumably callers keep those bits zero — confirm.
// zip() and matrix32x8::transpose() throw runtime_error if cpu_has_avx2()
// returns false.
void square64::transpose(int n_rows, int n_cols)
{
#ifdef DEBUG_TRANS
cout << "transpose" << endl;
print();
#endif
assert(n_rows <= 64);
assert(n_cols <= 64);
#ifndef __AVX2__
// Fallback: place each 64-bit row in the low 64 bits of a 128-bit row
// (_mm_cvtsi64_si128 zeroes the upper half), transpose 128x128, and read
// back the low 64 bits of the first n_cols rows.
// NOTE(review): relies on square128's bit layout keeping the 64x64 corner
// in the low halves of its first 64 rows — confirm against square128.
square128 tmp2;
tmp2.set_zero();
for (int i = 0; i < n_rows; i++)
tmp2.rows[i] = _mm_cvtsi64_si128(rows[i]);
tmp2.transpose();
*this = {};
for (int i = 0; i < n_cols; i++)
rows[i] = _mm_cvtsi128_si64(tmp2.rows[i]);
return;
// The AVX2 code below is still compiled in this configuration but never
// runs.
#endif
// Read from a copy, rebuild *this from zero.
square64 tmp = *this;
*this = {};
// Iteration k handles the slab of rows 32k .. 32k+31.
for (int k = 0; k < DIV_CEIL(n_rows, 32); k++)
{
// Load the slab as 8 registers of four 64-bit rows each, i.e. x[j] holds
// rows 32k+4j .. 32k+4j+3 (assuming quadrows aliases rows contiguously).
__m256i x[8], lows[4], highs[4];
memcpy(x, &tmp.quadrows[8 * k], sizeof(x));
#ifdef DEBUG_TRANS
// NOTE(review): long is only 32 bits on LLP64 platforms, where these
// debug prints would show only half of each register.
for (int j = 0; j < 8; j++)
if (not _mm256_testz_si256(x[j], x[j]))
{
cout << "transpose k " << k << " j " << j << ": ";
for (int i = 0; i < 4; i++)
cout << hex << " " << ((long*)&x[j])[i];
cout << dec << endl;
}
#endif
// Phase 1 (chunk sizes 128, then 64): move whole rows around so that
// afterwards x[m] holds slab rows r, r+8, r+16, r+24 with r = perm[m]
// (perm below is its own inverse).
for (int chunk_size = 128; chunk_size >= 64; chunk_size /= 2)
{
for (int j = 0; j < 4; j ++)
{
// Pick the pair of registers to interleave for output j.
int a, b;
if (chunk_size > 64)
{
a = j;
b = a + 4;
}
else if (chunk_size == 64)
{
// pairs (0,2), (1,3), (4,6), (5,7)
a = j / 2 * 2 + j;
b = a + 2;
}
else
{
// NOTE(review): unreachable, chunk_size is only 128 or 64 here.
a = 2 * j;
b = a + 1;
}
zip(chunk_size, lows[j], highs[j], x[a], x[b]);
}
// Interleaved lower parts become x[0..3], upper parts x[4..7].
memcpy(x, lows, sizeof(lows));
memcpy(&x[4], highs, sizeof(highs));
#ifdef DEBUG_TRANS
for (int j = 0; j < 8; j++)
if (not _mm256_testz_si256(x[j], x[j]))
{
cout << "transpose k " << k << " chunk " << chunk_size
<< " j " << j << ": ";
for (int i = 0; i < 4; i++)
cout << hex << " " << ((long*)&x[j])[i];
cout << dec << endl;
}
#endif
}
// Phase 2 (chunk sizes 8, 16, 32, 64): interleave bytes, then 16-, 32-
// and 64-bit units, so that after the last round x[perm[i]] holds byte i
// of each of the 32 slab rows, in row order.
for (int chunk_size = 8; chunk_size < 128; chunk_size *= 2)
{
for (int j = 0; j < 4; j ++)
{
// Default pairing (used for chunk size 16): (0,2), (1,3), (4,6), (5,7).
int a = j / 2 * 2 + j;
int b = a + 2;
// Chunk size 8 pairs (j, j+4).
if (chunk_size == 8)
{
a = j;
b = j + 4;
}
// Chunk sizes 64 and 32 both pair neighbours (2j, 2j+1).
if (chunk_size == 64)
{
a = 2 * j;
b = a + 1;
}
if (chunk_size == 32)
{
a = 2 * j;
b = a + 1;
}
zip(chunk_size, lows[j], highs[j], x[a], x[b]);
}
memcpy(x, lows, sizeof(lows));
memcpy(&x[4], highs, sizeof(highs));
#ifdef DEBUG_TRANS
for (int j = 0; j < 8; j++)
if (not _mm256_testz_si256(x[j], x[j]))
{
cout << "transpose k " << k << " chunk " << chunk_size
<< " j " << j << ": ";
for (int i = 0; i < 4; i++)
cout << hex << " " << ((long*)&x[j])[i];
cout << dec << endl;
}
#endif
}
// x[perm[i]] is byte column i (bits 8i .. 8i+7 of every slab row);
// transpose it into output.halfrows[8i .. 8i+7][k]. Only the byte
// columns covering the first n_cols columns are written.
int perm[] = { 0, 4, 2, 6, 1, 5, 3, 7 };
for (int i = 0; i < DIV_CEIL(n_cols, 8); i++)
{
matrix32x8(x[perm[i]]).transpose(*this, i, k);
}
}
#ifdef DEBUG_TRANS
cout << "after transpose" << endl;
print();
#endif
}
// Inequality: true if this matrix and other differ in at least one of
// their 64 rows, false if all rows are equal.
bool square64::operator !=(const square64& other)
{
    for (int i = 0; i < 64; i++)
        // a single differing row already makes the matrices unequal
        if (rows[i] != other.rows[i])
            return true;
    return false;
}
// Dump the 64x64 bit matrix to stdout: one line per row, with the value of
// get_bit(row, col) for each column, followed by a final flush.
void square64::print()
{
    const int side = 64;
    for (int row = 0; row < side; ++row)
    {
        for (int col = 0; col < side; ++col)
            cout << get_bit(row, col);
        cout << endl;
    }
    cout << flush;
}