Files
MP-SPDZ/FHE/FFT.cpp
Marcel Keller 6cc3fccef0 Maintenance.
2023-05-09 14:50:53 +10:00

262 lines
6.3 KiB
C++

#include "FHE/FFT.h"
#include "Math/Zp_Data.h"
#include "Processor/BaseMachine.h"
#include "Math/modp.hpp"
/* Computes the FFT via Horner's Rule
theta is assumed to be an Nth root of unity
*/
void NaiveFFT(vector<modp>& ans,vector<modp>& a,int N,const modp& theta,const Zp_Data& PrD)
{
int i,j;
modp thetaPow;
assignOne(thetaPow,PrD);
for (i=0; i<N; i++)
{ ans[i]=a[N-1];
for (j=N-2; j>=0; j--)
{ Mul(ans[i],ans[i],thetaPow,PrD);
Add(ans[i],ans[i],a[j],PrD);
}
Mul(thetaPow,thetaPow,theta,PrD);
}
}
void FFT(vector<modp>& a,int N,const modp& theta,const Zp_Data& PrD)
{
if (N==1) { return; }
if (N<5)
{ vector<modp> b(N);
NaiveFFT(b,a,N,theta,PrD);
a=b;
return;
}
vector<modp> a0(N/2),a1(N/2);
int i;
for (i=0; i<N/2; i++)
{ a0[i]=a[2*i];
a1[i]=a[2*i+1];
}
modp theta2,w,t;
Sqr(theta2,theta,PrD);
FFT(a0,N/2,theta2,PrD);
FFT(a1,N/2,theta2,PrD);
assignOne(w,PrD);
for (i=0; i<N/2; i++)
{ Mul(t,w,a1[i],PrD);
Add(a[i],a0[i],t,PrD);
Sub(a[i+N/2],a0[i],t,PrD);
Mul(w,w,theta,PrD);
}
}
/*
* Standard FFT for n a power of two, root a n-th primitive root of unity.
*/
template<class T,class P>
void FFT_Iter(vector<T>& ioput, int n, const T& root, const P& PrD)
{
int i, j, m;
T t;
// Bit-reversal of input
for( i = j = 0; i < n; ++i )
{
if( j >= i )
{
t = ioput[i];
ioput[i] = ioput[j];
ioput[j] = t;
}
m = n / 2;
while( (m >= 1) && (j >= m) )
{
j -= m;
m /= 2;
}
j += m;
}
T u, alpha, alpha2;
m = 0; j = 0; i = 0;
// Do the transform
for (int s = 1; s < n; s = 2*s)
{
m = 2*s;
Power(alpha, root, n/m, PrD);
assignOne(alpha2,PrD);
for (int j = 0; j < m/2; ++j)
{
//root = root_table[j*n/m];
for (int k = j; k < n; k += m)
{
Mul(t, alpha2, ioput[k + m/2], PrD);
u = ioput[k];
Add(ioput[k], u, t, PrD);
Sub(ioput[k + m/2], u, t, PrD);
}
Mul(alpha2, alpha2, alpha, PrD);
}
}
}
/*
* FFT modulo x^n + 1.
*
* n must be a power of two, root a 2n-th primitive root of unity.
*/
void FFT_Iter2(vector<modp>& ioput, int n, const modp& root, const Zp_Data& PrD)
{
FFT_Iter(ioput, n, root, PrD, false);
}
void FFT_Iter2(vector<modp>& ioput, int n, const vector<modp>& roots,
const Zp_Data& PrD)
{
FFT_Iter(ioput, n, roots, PrD, false);
}
void FFT_Iter(vector<modp>& ioput, int n, const modp& root, const Zp_Data& PrD,
bool start_with_one)
{
vector<modp> roots(n + 1);
assignOne(roots[0], PrD);
for (int i = 1; i < n + 1; i++)
Mul(roots[i], roots[i - 1], root, PrD);
FFT_Iter(ioput, n, roots, PrD, start_with_one);
}
void FFT_Iter(vector<modp>& ioput, int n, const vector<modp>& roots,
const Zp_Data& PrD, bool start_with_one)
{
assert(roots.size() > size_t(n));
int i, j, m;
// Bit-reversal of input
for( i = j = 0; i < n; ++i )
{
if( j >= i )
{
swap(ioput[i], ioput[j]);
}
m = n / 2;
while( (m >= 1) && (j >= m) )
{
j -= m;
m /= 2;
}
j += m;
}
m = 0; j = 0; i = 0;
// Do the transform
vector<modp> alpha2;
alpha2.reserve(n / 2);
for (int s = 1; s < n; s = 2*s)
{
m = 2*s;
alpha2.clear();
if (start_with_one)
{
for (int j = 0; j < m / 2; j++)
alpha2.push_back(roots[j * n / m]);
}
else
{
for (int j = 0; j < m / 2; j++)
alpha2.push_back(roots.at((j * 2 + 1) * (n / m)));
}
if (BaseMachine::thread_num == 0 and BaseMachine::has_singleton())
{
auto& queues = BaseMachine::s().queues;
FftJob job(ioput, alpha2, m, PrD);
int start = queues.distribute(job, n / 2);
for (int i = start; i < n / 2; i++)
FFT_Iter2_body(ioput, alpha2, i, m, PrD);
if (start > 0)
queues.wrap_up(job);
}
else
for (int i = 0; i < n / 2; i++)
FFT_Iter2_body(ioput, alpha2, i, m, PrD);
}
}
/* This does FFT for X^N+1,
Input and output is an array of size N (shared)
alpha is assumed to be a generator of the N'th roots of unity mod p
Starts at w=alpha and updates by alpha^2
*/
void FFT2(vector<modp>& a, int N, const modp& alpha, const Zp_Data& PrD)
{
int i;
if (N==1) { return; }
vector<modp> a0(N/2),a1(N/2);
for (i=0; i<N/2; i++)
{ a0[i]=a[2*i];
a1[i]=a[2*i+1];
}
modp w,alpha2,temp;
Sqr(alpha2,alpha,PrD);
FFT2(a0,N/2,alpha2,PrD); FFT2(a1,N/2,alpha2,PrD);
w=alpha;
for (i=0; i<N/2; i++)
{ Mul(temp,w,a1[i],PrD);
Add(a[i],a0[i],temp,PrD);
Sub(a[i+N/2],a0[i],temp,PrD);
Mul(w,w,alpha2,PrD);
}
}
void FFT_non_power_of_two(vector<modp>& res, const vector<modp>& input, const FFT_Data& FFTD)
{
vector<modp> tmp(FFTD.m());
BFFT(tmp, input, FFTD);
for (int i = 0; i < (FFTD).phi_m(); i++)
res[i] = tmp[(FFTD).p(i)];
}
void BFFT(vector<modp>& ans,const vector<modp>& a,const FFT_Data& FFTD,bool forward)
{
int k2=FFTD.twop,n=FFTD.m();
if (k2<0) { k2=-k2; }
int r=0;
if (forward==false) { r=1; }
if (FFTD.twop>0)
{ vector<modp> x(k2);
for (unsigned int i=0; i<a.size(); i++)
{ Mul(x[i],FFTD.powers[r][i],a[i],FFTD.get_prD()); }
for (int i=a.size(); i<k2; i++)
{ assignZero(x[i],FFTD.get_prD()); }
FFT_Iter(x,k2,FFTD.two_root[0],FFTD.get_prD());
for (int i=0; i<k2; i++)
{ Mul(x[i],x[i],FFTD.b[r][i],FFTD.get_prD()); }
FFT_Iter(x,k2,FFTD.two_root[1],FFTD.get_prD());
for (int i=0; i<n; i++)
{ Mul(ans[i],x[i+n-1],FFTD.powers_i[r][i],FFTD.get_prD()); }
}
else
{ throw crash_requested(); }
}