// Mirror of https://github.com/data61/MP-SPDZ.git
// Synced 2026-01-08 21:18:03 -05:00
// 362 lines, 7.9 KiB, C++
#include "Rq_Element.h"
|
|
#include "FHE_Keys.h"
|
|
#include "Tools/Exceptions.h"
|
|
|
|
#include "Math/modp.hpp"
|
|
|
|
Rq_Element::Rq_Element(const FHE_PK& pk) :
|
|
Rq_Element(pk.get_params().FFTD())
|
|
{
|
|
}
|
|
|
|
Rq_Element::Rq_Element(const vector<FFT_Data>& prd, RepType r0, RepType r1)
|
|
{
|
|
if (prd.size() > 0)
|
|
a.push_back({prd[0], r0});
|
|
if (prd.size() > 1)
|
|
a.push_back({prd[1], r1});
|
|
lev = n_mults();
|
|
}
|
|
|
|
void Rq_Element::set_data(const vector<FFT_Data>& prd)
|
|
{
|
|
a.resize(prd.size());
|
|
for(size_t i = 0; i < a.size(); i++)
|
|
a[i].set_data(prd[i]);
|
|
lev=n_mults();
|
|
}
|
|
|
|
void Rq_Element::assign_zero(const vector<FFT_Data>& prd)
|
|
{
|
|
set_data(prd);
|
|
assign_zero();
|
|
}
|
|
|
|
void Rq_Element::assign_zero()
|
|
{
|
|
for (int i=0; i<=lev; ++i)
|
|
a[i].assign_zero();
|
|
}
|
|
|
|
void Rq_Element::assign_one()
|
|
{
|
|
for (int i=0; i<=lev; ++i)
|
|
a[i].assign_one();
|
|
}
|
|
|
|
void Rq_Element::partial_assign(const Rq_Element& other)
|
|
{
|
|
lev=other.lev;
|
|
a.resize(other.a.size());
|
|
}
|
|
|
|
void Rq_Element::negate()
|
|
{
|
|
for (int i=0; i<=lev; ++i)
|
|
a[i].negate();
|
|
}
|
|
|
|
Rq_Element Rq_Element::mul_by_X_i(int i) const
|
|
{
|
|
Rq_Element res;
|
|
res.lev = lev;
|
|
res.a.clear();
|
|
for (auto& x : a)
|
|
{
|
|
auto tmp = x.mul_by_X_i(i);
|
|
res.a.push_back(tmp);
|
|
}
|
|
return res;
|
|
}
|
|
|
|
void add(Rq_Element& ans,const Rq_Element& ra,const Rq_Element& rb)
|
|
{
|
|
ans.partial_assign(ra, rb);
|
|
for (int i=0; i<=ans.lev; ++i)
|
|
add(ans.a[i],ra.a[i],rb.a[i]);
|
|
if (ans.lev == 0 && ans.n_mults() == 1) {
|
|
ans.a[1].partial_assign(ra.a[1]);
|
|
}
|
|
}
|
|
void sub(Rq_Element& ans,const Rq_Element& a,const Rq_Element& b)
|
|
{
|
|
ans.partial_assign(a, b);
|
|
for (int i=0; i<=ans.lev; ++i)
|
|
sub(ans.a[i],a.a[i],b.a[i]);
|
|
if (ans.lev == 0 && ans.n_mults() == 1) {
|
|
ans.a[1].partial_assign(a.a[1]);
|
|
}
|
|
}
|
|
|
|
void mul(Rq_Element& ans,const Rq_Element& a,const Rq_Element& b)
|
|
{
|
|
ans.partial_assign(a, b);
|
|
for (int i=0; i<=ans.lev; ++i)
|
|
mul(ans.a[i],a.a[i],b.a[i]);
|
|
if (ans.lev == 0 && ans.n_mults() == 1) {
|
|
ans.a[1].partial_assign(a.a[1]);
|
|
}
|
|
}
|
|
|
|
void mul(Rq_Element& ans,const Rq_Element& a,const bigint& b)
|
|
{
|
|
ans.partial_assign(a);
|
|
modp bp;
|
|
for (int i=0; i<=ans.lev; ++i)
|
|
{
|
|
to_modp(bp,b,a.a[i].get_prD());
|
|
mul(ans.a[i],a.a[i],bp);
|
|
}
|
|
}
|
|
|
|
void Rq_Element::randomize(PRNG& G,int l)
|
|
{
|
|
set_level(l);
|
|
for (int i=0; i<=lev; ++i)
|
|
a[i].randomize(G);
|
|
}
|
|
|
|
bool Rq_Element::equals(const Rq_Element& other) const
|
|
{
|
|
if (lev!=other.lev) { throw level_mismatch(); }
|
|
for (int i=0; i<=lev; ++i)
|
|
if (!a[i].equals(other.a[i])) return false;
|
|
return true;
|
|
}
|
|
|
|
|
|
vector<bigint> Rq_Element::to_vec_bigint() const
|
|
{
|
|
vector<bigint> v;
|
|
to_vec_bigint(v);
|
|
return v;
|
|
}
|
|
|
|
// Doing sort of CRT;
|
|
// result mod p0 = a[0]; result mod p1 = a[1]
|
|
void Rq_Element::to_vec_bigint(vector<bigint>& v) const
|
|
{
|
|
a[0].to_vec_bigint(v);
|
|
if (n_mults() == 0) {
|
|
bigint p0 = a[0].get_prime();
|
|
for (size_t i = 0; i < v.size(); ++i) {
|
|
if (v[i] > p0 / 2) {
|
|
v[i] = (v[i] - p0);
|
|
}
|
|
}
|
|
}
|
|
if (lev==1)
|
|
{ vector<bigint> v1;
|
|
a[1].to_vec_bigint(v1);
|
|
bigint p0=a[0].get_prime();
|
|
bigint p1=a[1].get_prime();
|
|
bigint p0i,lambda,Q=p0*p1;
|
|
invMod(p0i,p0%p1,p1);
|
|
for (unsigned int i=0; i<v.size(); i++)
|
|
{ lambda=((v1[i]-v[i])*p0i)%Q;
|
|
v[i]=(v[i]+p0*lambda)%Q;
|
|
}
|
|
}
|
|
}
|
|
|
|
ConversionIterator Rq_Element::get_iterator() const
|
|
{
|
|
if (lev != 0)
|
|
throw not_implemented();
|
|
return a[0].get_iterator();
|
|
}
|
|
|
|
bigint Rq_Element::infinity_norm() const
|
|
{
|
|
bigint Q = 1, ans = 0;
|
|
for (int i = 0; i <= n_mults(); ++i)
|
|
{
|
|
Q *= a[i].get_prime();
|
|
}
|
|
bigint t;
|
|
vector<bigint> te=to_vec_bigint();
|
|
for (unsigned int i=0; i<te.size(); i++)
|
|
{ // Take rounded value and then abs value
|
|
if (te[i]<Q/2) { t=te[i]; }
|
|
else { t=Q-te[i]; }
|
|
if (t>ans) { ans=t; }
|
|
}
|
|
return ans;
|
|
}
|
|
|
|
void Rq_Element::change_rep(RepType r)
|
|
{
|
|
if (lev==1) { throw level_mismatch(); }
|
|
a[0].change_rep(r);
|
|
}
|
|
|
|
void Rq_Element::change_rep(RepType r0,RepType r1)
|
|
{
|
|
if (lev==0 or n_mults() != 1) { throw level_mismatch(); }
|
|
a[0].change_rep(r0);
|
|
a[1].change_rep(r1);
|
|
}
|
|
|
|
void Rq_Element::Scale(const bigint& p)
|
|
{
|
|
if (lev==0) { return; }
|
|
if (n_mults() == 0) {
|
|
//for some reason we scale but we have just one level
|
|
throw level_mismatch();
|
|
}
|
|
bigint p0=a[0].get_prime(),p1=a[1].get_prime(),p1i,lambda,n=p1*p;
|
|
invMod(p1i,p1%p,p);
|
|
|
|
// First multiply input by [p1]_p
|
|
bigint te=p1%p;
|
|
if (te>p/2) { te-=p; }
|
|
modp tep;
|
|
to_modp(tep,te,a[0].get_prD());
|
|
mul(a[0],a[0],tep);
|
|
to_modp(tep,te,a[1].get_prD());
|
|
mul(a[1],a[1],tep);
|
|
|
|
// Now compute delta
|
|
Ring_Element b0(a[0].get_FFTD(),evaluation);
|
|
Ring_Element b1(a[1].get_FFTD(),evaluation);
|
|
// scope to ensure deconstruction of write iterators
|
|
{
|
|
auto poly_a1 = a[1];
|
|
poly_a1.change_rep(polynomial);
|
|
auto it = poly_a1.get_iterator();
|
|
auto it0 = b0.get_write_iterator();
|
|
auto it1 = b1.get_write_iterator();
|
|
bigint half_n = n / 2;
|
|
bigint delta;
|
|
for (int i=0; i < a[1].get_FFTD().phi_m(); i++)
|
|
{
|
|
it.get(delta);
|
|
lambda = delta;
|
|
lambda *= p1i;
|
|
lambda %= p;
|
|
lambda *= p1;
|
|
lambda -= delta;
|
|
lambda %= n;
|
|
if (lambda > half_n)
|
|
lambda -= n;
|
|
it0.get(lambda);
|
|
it1.get(lambda);
|
|
}
|
|
}
|
|
|
|
// Now add delta back onto a0
|
|
Rq_Element bb(b0,b1);
|
|
add(*this,*this,bb);
|
|
|
|
// Now divide by p1 mod p0
|
|
modp p1_inv,pp;
|
|
to_modp(pp,p1,a[0].get_prD());
|
|
Inv(p1_inv,pp,a[0].get_prD());
|
|
lev=0;
|
|
mul(a[0],a[0],p1_inv);
|
|
}
|
|
|
|
void Rq_Element::mul_by_p1()
|
|
{
|
|
|
|
if (n_mults() == 0) {throw level_mismatch();}
|
|
lev=1;
|
|
bigint m=a[1].get_prime()%a[0].get_prime();
|
|
modp mp;
|
|
to_modp(mp,m,a[0].get_prD());
|
|
mul(a[0],a[0],mp);
|
|
a[1].assign_zero();
|
|
}
|
|
|
|
void Rq_Element::raise_level()
|
|
{
|
|
if (lev==n_mults()) { return; }
|
|
lev=1;
|
|
a[1].from(a[0].get_copy_iterator());
|
|
}
|
|
|
|
void Rq_Element::check_level() const
|
|
{
|
|
if ((unsigned)lev > (unsigned)n_mults())
|
|
throw range_error(
|
|
"level out of range: " + to_string(lev) + "/" + to_string(n_mults()));
|
|
}
|
|
|
|
void Rq_Element::partial_assign(const Rq_Element& x, const Rq_Element& y)
|
|
{
|
|
x.check_level();
|
|
y.check_level();
|
|
if (x.lev != y.lev or x.n_mults() != y.n_mults())
|
|
throw level_mismatch();
|
|
partial_assign(x);
|
|
}
|
|
|
|
void Rq_Element::pack(octetStream& o) const
|
|
{
|
|
check_level();
|
|
o.store(lev);
|
|
for (int i = 0; i <= lev; ++i)
|
|
a[i].pack(o);
|
|
}
|
|
|
|
void Rq_Element::unpack(octetStream& o)
|
|
{
|
|
unsigned int ll; o.get(ll); lev=ll;
|
|
check_level();
|
|
for (int i = 0; i <= lev; ++i)
|
|
a[i].unpack(o);
|
|
}
|
|
|
|
void Rq_Element::output(ostream& s) const
|
|
{
|
|
check_level();
|
|
s.write((char*)&lev, sizeof(lev));
|
|
for (int i = 0; i <= lev; i++)
|
|
a[i].output(s);
|
|
}
|
|
|
|
void Rq_Element::input(istream& s)
|
|
{
|
|
s.read((char*)&lev, sizeof(lev));
|
|
check_level();
|
|
for (int i = 0; i <= lev; i++)
|
|
a[i].input(s);
|
|
}
|
|
|
|
void Rq_Element::check(const FHE_Params& params) const
|
|
{
|
|
if (n_mults() != params.n_mults())
|
|
throw level_mismatch();
|
|
for (int i = 0; i <= lev; i++)
|
|
a[i].check(params.FFTD()[i]);
|
|
}
|
|
|
|
size_t Rq_Element::report_size(ReportType type) const
|
|
{
|
|
size_t sz = a[0].report_size(type);
|
|
if (lev == 1 || type == CAPACITY)
|
|
if (n_mults() == 1)
|
|
sz += a[1].report_size(type);
|
|
return sz;
|
|
}
|
|
|
|
void Rq_Element::print_first_non_zero() const
|
|
{
|
|
vector<bigint> v = to_vec_bigint();
|
|
size_t i;
|
|
for (i = 0; i < v.size(); i++)
|
|
{
|
|
if (v[i] != 0)
|
|
{
|
|
cout << i << ":" << v[i];
|
|
break;
|
|
}
|
|
}
|
|
if (i == v.size())
|
|
cout << "ZERO" << endl;
|
|
cout << endl;
|
|
}
|
|
|
|
// Explicit instantiation of Rq_Element::from for Generator<bigint>, so the
// template member is emitted in this translation unit.
template void Rq_Element::from<bigint>(const Generator<bigint>&, int);
|
|
// Explicit instantiation of Rq_Element::from for Generator<int>, so the
// template member is emitted in this translation unit.
template void Rq_Element::from<int>(const Generator<int>&, int);
|