mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-01-10 14:08:09 -05:00
149 lines
4.5 KiB
C++
149 lines
4.5 KiB
C++
#include "OTExtension.h"
|
|
#include "OTExtensionWithMatrix.h"
|
|
|
|
#include "OT/Tools.h"
|
|
#include "Math/gf2n.h"
|
|
#include "Tools/aes.h"
|
|
#include "Tools/MMO.h"
|
|
#include <wmmintrin.h>
|
|
#include <emmintrin.h>
|
|
|
|
|
|
OTExtension::OTExtension(const BaseOT& baseOT, TwoPartyPlayer* player,
|
|
bool passive) : player(player)
|
|
{
|
|
ot_role = INV_ROLE(baseOT.ot_role);
|
|
passive_only = passive;
|
|
init(baseOT.receiver_inputs, baseOT.sender_inputs, baseOT.receiver_outputs);
|
|
}
|
|
|
|
|
|
// test if a == b
|
|
int eq_m128i(__m128i a, __m128i b)
|
|
{
|
|
__m128i vcmp = _mm_cmpeq_epi8(a, b);
|
|
uint16_t vmask = _mm_movemask_epi8(vcmp);
|
|
return (vmask == 0xffff);
|
|
}
|
|
|
|
|
|
|
|
void OTExtensionWithMatrix::check_correlation(int nOTs,
|
|
const BitVector& receiverInput)
|
|
{
|
|
//cout << "\tStarting correlation check\n" << flush;
|
|
#ifdef OTEXT_TIMER
|
|
timeval startv, endv;
|
|
gettimeofday(&startv, NULL);
|
|
#endif
|
|
if (nbaseOTs != 128)
|
|
{
|
|
cerr << "Correlation check not implemented for length != 128\n";
|
|
throw not_implemented();
|
|
}
|
|
GlobalPRNG G(*player);
|
|
#ifdef OTEXT_TIMER
|
|
gettimeofday(&endv, NULL);
|
|
double elapsed = timeval_diff(&startv, &endv);
|
|
cout << "\t\tCommitment for seed took time " << elapsed/1000000 << endl << flush;
|
|
times["Commitment for seed"] += timeval_diff(&startv, &endv);
|
|
gettimeofday(&startv, NULL);
|
|
#endif
|
|
|
|
__m128i Delta, x128i;
|
|
Delta = _mm_load_si128((__m128i*)&(baseReceiverInput.get_ptr()[0]));
|
|
|
|
BitVector chi(nbaseOTs);
|
|
BitVector x(nbaseOTs);
|
|
__m128i t = _mm_setzero_si128();
|
|
__m128i q = _mm_setzero_si128();
|
|
__m128i t2 = _mm_setzero_si128();
|
|
__m128i q2 = _mm_setzero_si128();
|
|
__m128i chii, ti, qi, ti2, qi2;
|
|
x128i = _mm_setzero_si128();
|
|
|
|
for (int i = 0; i < nOTs; i++)
|
|
{
|
|
// chi.randomize(G);
|
|
// chii = _mm_load_si128((__m128i*)&(chi.get_ptr()[0]));
|
|
chii = G.get_doubleword();
|
|
|
|
if (ot_role & RECEIVER)
|
|
{
|
|
if (receiverInput.get_bit(i) == 1)
|
|
{
|
|
x128i = _mm_xor_si128(x128i, chii);
|
|
}
|
|
ti = _mm_loadu_si128((__m128i*)get_receiver_output(i));
|
|
// multiply over polynomial ring to avoid reduction
|
|
mul128(ti, chii, &ti, &ti2);
|
|
t = _mm_xor_si128(t, ti);
|
|
t2 = _mm_xor_si128(t2, ti2);
|
|
}
|
|
if (ot_role & SENDER)
|
|
{
|
|
qi = _mm_loadu_si128((__m128i*)(get_sender_output(0, i)));
|
|
mul128(qi, chii, &qi, &qi2);
|
|
q = _mm_xor_si128(q, qi);
|
|
q2 = _mm_xor_si128(q2, qi2);
|
|
}
|
|
}
|
|
#ifdef OTEXT_DEBUG
|
|
if (ot_role & RECEIVER)
|
|
{
|
|
cout << "\tSending x,t\n";
|
|
cout << "\tsend x = " << __m128i_toString<octet>(x128i) << endl;
|
|
cout << "\tsend t = " << __m128i_toString<octet>(t) << endl;
|
|
cout << "\tsend t2 = " << __m128i_toString<octet>(t2) << endl;
|
|
}
|
|
#endif
|
|
check_iteration(Delta, q, q2, t, t2, x128i);
|
|
#ifdef OTEXT_TIMER
|
|
gettimeofday(&endv, NULL);
|
|
elapsed = timeval_diff(&startv, &endv);
|
|
cout << "\t\tChecking correlation took time " << elapsed/1000000 << endl << flush;
|
|
times["Checking correlation"] += timeval_diff(&startv, &endv);
|
|
#endif
|
|
}
|
|
|
|
void OTExtensionWithMatrix::check_iteration(__m128i delta, __m128i q, __m128i q2,
|
|
__m128i t, __m128i t2, __m128i x)
|
|
{
|
|
vector<octetStream> os(2);
|
|
// send x, t;
|
|
__m128i received_t, received_t2, received_x, tmp1, tmp2;
|
|
if (ot_role & RECEIVER)
|
|
{
|
|
os[0].append((octet*)&x, sizeof(x));
|
|
os[0].append((octet*)&t, sizeof(t));
|
|
os[0].append((octet*)&t2, sizeof(t2));
|
|
}
|
|
send_if_ot_receiver(player, os, ot_role);
|
|
|
|
if (ot_role & SENDER)
|
|
{
|
|
os[1].consume((octet*)&received_x, sizeof(received_x));
|
|
os[1].consume((octet*)&received_t, sizeof(received_t));
|
|
os[1].consume((octet*)&received_t2, sizeof(received_t2));
|
|
|
|
// check t = x * Delta + q
|
|
//gfmul128(received_x, delta, &tmp1);
|
|
mul128(received_x, delta, &tmp1, &tmp2);
|
|
tmp1 = _mm_xor_si128(tmp1, q);
|
|
tmp2 = _mm_xor_si128(tmp2, q2);
|
|
|
|
if (eq_m128i(tmp1, received_t) && eq_m128i(tmp2, received_t2))
|
|
{
|
|
//cout << "\tCheck passed\n";
|
|
}
|
|
else
|
|
{
|
|
cerr << "Correlation check failed\n";
|
|
cout << "rec t = " << __m128i_toString<octet>(received_t) << endl;
|
|
cout << "tmp1 = " << __m128i_toString<octet>(tmp1) << endl;
|
|
cout << "q = " << __m128i_toString<octet>(q) << endl;
|
|
throw runtime_error("correlation check");
|
|
}
|
|
}
|
|
}
|