// Source: MP-SPDZ/FHE/FFT.cpp
// (repository-viewer metadata: 231 lines, 5.3 KiB, C++)

#include "FHE/FFT.h"
#include "Math/Zp_Data.h"
#include "Math/modp.hpp"
// NaiveFFT: quadratic-time DFT, used as the small-size base case of FFT().
//
// For each output index row in [0, N) this evaluates
//   A(x) = a[0] + a[1] x + ... + a[N-1] x^(N-1)
// at x = theta^row with Horner's rule, giving
//   ans[row] = sum_j a[j] * theta^(row*j)   (arithmetic mod p via PrD).
// theta is assumed to be an N-th root of unity.
//
// ans and a must each hold at least N elements and must be different
// vectors: ans[row] is overwritten while a is still being read.
void NaiveFFT(vector<modp>& ans,vector<modp>& a,int N,const modp& theta,const Zp_Data& PrD)
{
  modp point;                 // theta^row: the evaluation point of the current row
  assignOne(point, PrD);
  for (int row = 0; row < N; ++row)
    {
      modp& acc = ans[row];
      acc = a[N - 1];         // Horner starts from the leading coefficient
      for (int col = N - 2; col >= 0; --col)
        {
          Mul(acc, acc, point, PrD);
          Add(acc, acc, a[col], PrD);
        }
      Mul(point, point, theta, PrD);
    }
}
// FFT: recursive radix-2 (decimation-in-time) DFT, in place on a[0..N-1].
//
// On return a[i] = sum_j a_in[j] * theta^(i*j) for 0 <= i < N, where a_in is
// the input. theta is assumed to be a primitive N-th root of unity; the
// a[k + N/2] butterfly relies on theta^(N/2) = -1.
//
// Sizes below 5 are handed to NaiveFFT (and a is then assigned an N-element
// vector). Every size >= 5 met during the recursion is split in half, so N
// should be a power of two (more generally, repeated halving must stay even
// until the size drops below 5). An odd N >= 5 silently ignores its last
// coefficient.
void FFT(vector<modp>& a,int N,const modp& theta,const Zp_Data& PrD)
{
  if (N == 1)
    return;
  if (N < 5)
    {
      vector<modp> out(N);
      NaiveFFT(out, a, N, theta, PrD);
      a = out;
      return;
    }

  const int half = N / 2;
  vector<modp> even(half), odd(half);
  for (int k = 0; k < half; ++k)
    {
      even[k] = a[2 * k];
      odd[k] = a[2 * k + 1];
    }

  // Both halves are length-N/2 DFTs, whose root of unity is theta^2.
  modp thetaSq;
  Sqr(thetaSq, theta, PrD);
  FFT(even, half, thetaSq, PrD);
  FFT(odd, half, thetaSq, PrD);

  // Butterflies: a[k] = E[k] + theta^k O[k], a[k + N/2] = E[k] - theta^k O[k].
  modp twiddle, prod;
  assignOne(twiddle, PrD);
  for (int k = 0; k < half; ++k)
    {
      Mul(prod, twiddle, odd[k], PrD);
      Add(a[k], even[k], prod, PrD);
      Sub(a[k + half], even[k], prod, PrD);
      Mul(twiddle, twiddle, theta, PrD);
    }
}
/*
 * Iterative radix-2 Cooley-Tukey DFT, in place.
 *
 * n must be a power of two and root a primitive n-th root of unity.
 * The input is first permuted into bit-reversed order, then log2(n)
 * butterfly stages are applied, so that on return
 *   ioput[k] = sum_j in[j] * root^(j*k)   for 0 <= k < n.
 * T is the element type; PrD is passed through unchanged to the free
 * functions Power, assignOne, Mul, Add and Sub.
 */
template<class T,class P>
void FFT_Iter(vector<T>& ioput, int n, const T& root, const P& PrD)
{
  T tmp;

  // Bit-reversal permutation: rev is the bit-reversed counterpart of idx.
  int rev = 0;
  for (int idx = 0; idx < n; ++idx)
    {
      if (rev >= idx)
        {
          tmp = ioput[idx];
          ioput[idx] = ioput[rev];
          ioput[rev] = tmp;
        }
      // Reverse-binary increment of rev: clear leading ones, then set the next bit.
      int bit = n / 2;
      while (bit >= 1 && rev >= bit)
        {
          rev -= bit;
          bit /= 2;
        }
      rev += bit;
    }

  // Butterfly stages: each merges pairs of length-`half` DFTs into length 2*half.
  T upper, step, twiddle;
  for (int half = 1; half < n; half *= 2)
    {
      const int span = 2 * half;
      Power(step, root, n / span, PrD);   // primitive span-th root of unity
      assignOne(twiddle, PrD);            // step^offset, starting at step^0
      for (int offset = 0; offset < half; ++offset)
        {
          for (int k = offset; k < n; k += span)
            {
              Mul(tmp, twiddle, ioput[k + half], PrD);
              upper = ioput[k];
              Add(ioput[k], upper, tmp, PrD);
              Sub(ioput[k + half], upper, tmp, PrD);
            }
          Mul(twiddle, twiddle, step, PrD);
        }
    }
}
/*
 * Negacyclic FFT (evaluation modulo x^n + 1), in place.
 *
 * n must be a power of two and root a primitive 2n-th root of unity
 * (so root^n = -1). The structure matches the plain iterative FFT
 * (bit reversal followed by butterfly stages). The twiddle at position j of
 * a stage of length m is the odd power root^((2j+1) * n/m). As a result the
 * input polynomial A is evaluated at the odd powers of root, which are the
 * roots of x^n + 1: ioput[k] = A(root^(2k+1)).
 */
void FFT_Iter2(vector<modp>& ioput, int n, const modp& root, const Zp_Data& PrD)
{
  modp tmp;

  // Bit-reversal permutation: rev is the bit-reversed counterpart of idx.
  int rev = 0;
  for (int idx = 0; idx < n; ++idx)
    {
      if (rev >= idx)
        {
          tmp = ioput[idx];
          ioput[idx] = ioput[rev];
          ioput[rev] = tmp;
        }
      int bit = n / 2;
      while (bit >= 1 && rev >= bit)
        {
          rev -= bit;
          bit /= 2;
        }
      rev += bit;
    }

  // Butterfly stages with odd-power twiddles.
  modp upper, step, twiddle;
  for (int half = 1; half < n; half *= 2)
    {
      const int span = 2 * half;
      Power(twiddle, root, n / span, PrD);   // root^(n/span): first odd twiddle
      Mul(step, twiddle, twiddle, PrD);      // advance two odd exponents at a time
      for (int offset = 0; offset < half; ++offset)
        {
          for (int k = offset; k < n; k += span)
            {
              Mul(tmp, twiddle, ioput[k + half], PrD);
              upper = ioput[k];
              Add(ioput[k], upper, tmp, PrD);
              Sub(ioput[k + half], upper, tmp, PrD);
            }
          Mul(twiddle, twiddle, step, PrD);
        }
    }
}
/* FFT2: recursive FFT modulo X^N + 1 (negacyclic), in place on a[0..N-1].

   Input: the N coefficients of A(X). Output: a[k] = A(alpha^(2k+1)) for
   0 <= k < N. That is, A is evaluated at the odd powers of alpha, which are
   the roots of X^N + 1.

   Preconditions: N is a power of two, and alpha is a primitive 2N-th root of
   unity mod p, so alpha^N = -1 (the a[k + N/2] butterfly relies on this).
   Sizes N <= 1 are returned unchanged. A single coefficient is already its
   own evaluation. N == 0 must not recurse, because N/2 == 0 would call FFT2
   with N == 0 again, forever.

   The twiddle starts at alpha and is multiplied by alpha^2 after every
   butterfly, so it runs through alpha, alpha^3, ..., alpha^(N-1).
*/
void FFT2(vector<modp>& a, int N, const modp& alpha, const Zp_Data& PrD)
{
  if (N <= 1)
    return;

  const int half = N / 2;
  vector<modp> even(half), odd(half);
  for (int k = 0; k < half; ++k)
    {
      even[k] = a[2 * k];
      odd[k] = a[2 * k + 1];
    }

  // The half-size subproblems live modulo X^(N/2) + 1, whose 2*(N/2)-th root is alpha^2.
  modp alphaSq;
  Sqr(alphaSq, alpha, PrD);
  FFT2(even, half, alphaSq, PrD);
  FFT2(odd, half, alphaSq, PrD);

  modp twiddle = alpha, prod;   // twiddle == alpha^(2k+1) inside the loop
  for (int k = 0; k < half; ++k)
    {
      Mul(prod, twiddle, odd[k], PrD);
      Add(a[k], even[k], prod, PrD);
      Sub(a[k + half], even[k], prod, PrD);   // alpha^(2(k+N/2)+1) = -alpha^(2k+1)
      Mul(twiddle, twiddle, alphaSq, PrD);
    }
}
// Runs the Bluestein transform BFFT on `input` into a scratch buffer of
// FFTD.m() elements. BFFT is called without the `forward` flag, so the
// header's default is used (presumably true — confirm in FFT.h).
// Only the FFTD.phi_m() entries selected by FFTD.p(i) are kept:
//   res[i] = tmp[FFTD.p(i)]   for 0 <= i < FFTD.phi_m().
// res is not resized here; it must already hold at least FFTD.phi_m() elements.
void FFT_non_power_of_two(vector<modp>& res, const vector<modp>& input, const FFT_Data& FFTD)
{
  vector<modp> full(FFTD.m());
  BFFT(full, input, FFTD);

  const auto count = FFTD.phi_m();
  for (int i = 0; i < count; ++i)
    res[i] = full[FFTD.p(i)];
}
// BFFT: length-n transform (n = FFTD.m()) computed with Bluestein's method
// on top of the power-of-two FFT_Iter.
//
// Only FFTD.twop > 0 is supported; any other value throws crash_requested.
// With k2 = FFTD.twop used as the convolution length (FFT_Iter requires a
// power of two), the steps are:
//   1. x[i] = powers[r][i] * a[i] for i < a.size(), zero-padded up to k2;
//   2. FFT_Iter with two_root[0];
//   3. pointwise multiply by b[r];
//   4. FFT_Iter with two_root[1] (presumably the inverse root — TODO confirm);
//   5. ans[i] = x[i + n - 1] * powers_i[r][i] for 0 <= i < n.
// r selects the precomputed table row: 0 when forward is true, 1 otherwise.
// NOTE(review): assumes a.size() <= k2, k2 >= 2n - 1 and ans.size() >= n;
// none of these is checked here.
void BFFT(vector<modp>& ans,const vector<modp>& a,const FFT_Data& FFTD,bool forward)
{
  const int n = FFTD.m();
  if (FFTD.twop <= 0)
    throw crash_requested();

  const int k2 = FFTD.twop;
  const int r = forward ? 0 : 1;
  const auto& prD = FFTD.get_prD();

  // Step 1: scale the input by the chirp and zero-pad to length k2.
  vector<modp> x(k2);
  for (size_t i = 0; i < a.size(); ++i)
    Mul(x[i], FFTD.powers[r][i], a[i], prD);
  for (int i = static_cast<int>(a.size()); i < k2; ++i)
    assignZero(x[i], prD);

  // Steps 2-4: cyclic convolution with the precomputed b[r] via two FFTs.
  FFT_Iter(x, k2, FFTD.two_root[0], prD);
  for (int i = 0; i < k2; ++i)
    Mul(x[i], x[i], FFTD.b[r][i], prD);
  FFT_Iter(x, k2, FFTD.two_root[1], prD);

  // Step 5: read the n useful outputs (offset n-1) and apply the final chirp.
  for (int i = 0; i < n; ++i)
    Mul(ans[i], x[i + n - 1], FFTD.powers_i[r][i], prD);
}