mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-01-14 16:08:04 -05:00
144 lines
3.9 KiB
C++
144 lines
3.9 KiB
C++
/*
|
|
* MMO.cpp
|
|
*
|
|
*
|
|
*/
|
|
|
|
#include "MMO.h"
|
|
#include "Math/gf2n.h"
|
|
#include "Math/gfp.h"
|
|
#include "Math/bigint.h"
|
|
#include "Math/Z2k.h"
|
|
#include "Math/BitVec.h"
|
|
#include <unistd.h>
|
|
|
|
|
|
void MMO::zeroIV()
|
|
{
|
|
if (N_KEYS > (1 << 8))
|
|
throw not_implemented();
|
|
for (int i = 0; i < N_KEYS; i++)
|
|
{
|
|
octet key[AES_BLK_SIZE];
|
|
memset(key, 0, AES_BLK_SIZE * sizeof(octet));
|
|
key[i] = i;
|
|
setIV(i, key);
|
|
}
|
|
}
|
|
|
|
|
|
void MMO::setIV(int i, octet key[AES_BLK_SIZE])
|
|
{
|
|
aes_schedule(IV[i],key);
|
|
}
|
|
|
|
|
|
template<int N>
|
|
void MMO::encrypt_and_xor(void* output, const void* input, const octet* key,
|
|
const int* indices)
|
|
{
|
|
__m128i in[N], out[N];
|
|
for (int i = 0; i < N; i++)
|
|
in[i] = _mm_loadu_si128(((__m128i*)input) + indices[i]);
|
|
encrypt_and_xor<N>(out, in, key);
|
|
for (int i = 0; i < N; i++)
|
|
_mm_storeu_si128(((__m128i*)output) + indices[i], out[i]);
|
|
}
|
|
|
|
template <int N, int N_BYTES>
|
|
void MMO::hashBlocks(void* output, const void* input, size_t alloc_size)
|
|
{
|
|
size_t used_size = N_BYTES;
|
|
int n_blocks = DIV_CEIL(used_size, 16);
|
|
if (n_blocks > N_KEYS)
|
|
throw runtime_error("not enough MMO keys");
|
|
__m128i tmp[N];
|
|
size_t block_size = sizeof(tmp[0]);
|
|
for (int i = 0; i < n_blocks; i++)
|
|
{
|
|
encrypt_and_xor<N>(tmp, input, IV[i]);
|
|
for (int j = 0; j < N; j++)
|
|
memcpy((char*)output + j * alloc_size + i * block_size, &tmp[j],
|
|
min(used_size - i * block_size, block_size));
|
|
}
|
|
}
|
|
|
|
template <class T, int N>
|
|
void MMO::hashBlocks(void* output, const void* input)
|
|
{
|
|
hashBlocks<N, T::N_BYTES>(output, input, sizeof(T));
|
|
for (int j = 0; j < N; j++)
|
|
((T*)output + j)->normalize();
|
|
}
|
|
|
|
template <>
|
|
void MMO::hashBlocks<gfp1, 1>(void* output, const void* input)
|
|
{
|
|
if (gfp1::get_ZpD().get_t() != 2)
|
|
throw not_implemented();
|
|
encrypt_and_xor<1>(output, input, IV[0]);
|
|
while (mpn_cmp((mp_limb_t*)output, gfp1::get_ZpD().get_prA(), gfp1::t()) >= 0)
|
|
encrypt_and_xor<1>(output, output, IV[0]);
|
|
}
|
|
|
|
template <int X, int L>
|
|
void MMO::hashEightGfp(void* output, const void* input)
|
|
{
|
|
if (gfp_<X, L>::get_ZpD().get_t() < 2)
|
|
throw not_implemented();
|
|
gfp_<X, L>* out = (gfp_<X, L>*)output;
|
|
hashBlocks<8, gfp_<X, L>::N_BYTES>(output, input, sizeof(gfp_<X, L>));
|
|
for (int i = 0; i < 8; i++)
|
|
out[i].zero_overhang();
|
|
int left = 8;
|
|
int indices[8] = {0, 1, 2, 3, 4, 5, 6, 7};
|
|
while (left)
|
|
{
|
|
int now_left = 0;
|
|
for (int j = 0; j < left; j++)
|
|
if (mpn_cmp((mp_limb_t*) out[indices[j]].get_ptr(),
|
|
gfp_<X, L>::get_ZpD().get_prA(), gfp_<X, L>::t()) >= 0)
|
|
{
|
|
indices[now_left] = indices[j];
|
|
now_left++;
|
|
}
|
|
left = now_left;
|
|
|
|
int block_size = sizeof(__m128i);
|
|
int n_blocks = DIV_CEIL(gfp_<X, L>::size(), block_size);
|
|
for (int i = 0; i < n_blocks; i++)
|
|
for (int j = 0; j < left; j++)
|
|
{
|
|
__m128i* addr = (__m128i*) out[indices[j]].get_ptr() + i;
|
|
__m128i* in = (__m128i*) out[indices[j]].get_ptr();
|
|
auto tmp = aes_128_encrypt(_mm_loadu_si128(in), IV[i]);
|
|
memcpy(addr, &tmp, min(block_size, gfp_<X, L>::size() - i * block_size));
|
|
out[indices[j]].zero_overhang();
|
|
}
|
|
}
|
|
}
|
|
|
|
template <>
|
|
void MMO::hashBlocks<gfp1, 8>(void* output, const void* input)
|
|
{
|
|
hashEightGfp<1, GFP_MOD_SZ>(output, input);
|
|
}
|
|
|
|
template <>
|
|
void MMO::hashBlocks<gfp3, 8>(void* output, const void* input)
|
|
{
|
|
hashEightGfp<3, 4>(output, input);
|
|
}
|
|
|
|
#define ZZ(F,N) \
|
|
template void MMO::hashBlocks<F,N>(void*, const void*);
|
|
#define Z(F) ZZ(F,1) ZZ(F,2) ZZ(F,8)
|
|
Z(gf2n_long) Z(Z2<64>) Z(Z2<112>) Z(Z2<128>) Z(Z2<160>) Z(Z2<114>) Z(Z2<130>)
|
|
Z(Z2<72>)
|
|
Z(SignedZ2<64>) Z(SignedZ2<72>)
|
|
Z(gf2n_short)
|
|
Z(BitVec)
|
|
Z(Z2<41>)
|
|
Z(Z2<120>) Z(Z2<122>) Z(Z2<136>) Z(Z2<138>)
|
|
Z(Z2<65>) Z(Z2<49>)
|