#include "Zp_Data.h" #include "mpn_fixed.h" #include "gf2nlong.h" void Zp_Data::init(const bigint& p,bool mont) { #ifdef VERBOSE if (pr != 0) { if (pr != p) cerr << "Changing prime from " << pr << " to " << p << endl; if (mont != montgomery) cerr << "Changing Montgomery" << endl; } #endif if (not probPrime(p)) throw runtime_error(p.get_str() + " is not a prime"); pr=p; pr_half = p / 2; mask=static_cast(1ULL<<((mpz_sizeinbase(pr.get_mpz_t(),2)-1)%(8*sizeof(mp_limb_t))))-1; pr_byte_length = numBytes(pr); pr_bit_length = numBits(pr); int k = pr_bit_length; overhang = (uint64_t(-1LL) >> (63 - (k - 1) % 64)); montgomery=mont; t=mpz_size(pr.get_mpz_t()); if (t>MAX_MOD_SZ) throw max_mod_sz_too_small(t); if (montgomery) { inline_mpn_zero(R,MAX_MOD_SZ); inline_mpn_zero(R2,MAX_MOD_SZ); inline_mpn_zero(R3,MAX_MOD_SZ); bigint r=2,pp=pr; mpz_pow_ui(r.get_mpz_t(),r.get_mpz_t(),t*8*sizeof(mp_limb_t)); mpz_invert(pp.get_mpz_t(),pr.get_mpz_t(),r.get_mpz_t()); pp=r-pp; // pi=-1/p mod R pi=(pp.get_mpz_t()->_mp_d)[0]; r=r%pr; mpn_copyi(R,r.get_mpz_t()->_mp_d,mpz_size(r.get_mpz_t())); bigint r2=(r*r)%pr; mpn_copyi(R2,r2.get_mpz_t()->_mp_d,mpz_size(r2.get_mpz_t())); bigint r3=(r2*r)%pr; mpn_copyi(R3,r3.get_mpz_t()->_mp_d,mpz_size(r3.get_mpz_t())); if (sizeof(unsigned long)!=sizeof(mp_limb_t)) { cout << "The underlying types of MPIR mean we cannot use our Montgomery code" << endl; throw not_implemented(); } } inline_mpn_zero(prA,MAX_MOD_SZ+1); mpn_copyi(prA,pr.get_mpz_t()->_mp_d,t); } #include void Zp_Data::Mont_Mult(mp_limb_t* z,const mp_limb_t* x,const mp_limb_t* y,int t) const { mp_limb_t ans[2 * MAX_MOD_SZ + 1], u, yy[t + 1]; inline_mpn_copyi(yy, y, t); yy[t] = 0; // First loop u=x[0]*y[0]*pi; ans[t] = mpn_mul_1(ans,y,t,x[0]); ans[t+1] = mpn_addmul_1(ans,prA,t+1,u); for (int i=1; i=pr) { ans=z-pr; } // else { z=ans; } if (mpn_cmp(ans+t,prA,t+1)>=0) { mpn_sub_n(z,ans+t,prA,t); } else { inline_mpn_copyi(z,ans+t,t); } } void Zp_Data::Mont_Mult_switch(mp_limb_t* z, const mp_limb_t* x, const mp_limb_t* y) const { switch (t) { #ifdef __BMI2__ #define CASE(N) \ case N: \ Mont_Mult_(z, x, y); \ break; CASE(1) CASE(2) #if MAX_MOD_SZ >= 4 CASE(3) CASE(4) #endif #if MAX_MOD_SZ >= 5 CASE(5) #endif #if MAX_MOD_SZ >= 6 CASE(6) #endif #if MAX_MOD_SZ >= 10 CASE(7) CASE(8) CASE(9) CASE(10) #endif #undef CASE #endif default: Mont_Mult_variable(z, x, y); break; } } ostream& operator<<(ostream& s,const Zp_Data& ZpD) { s << ZpD.pr << " " << ZpD.montgomery << endl; if (ZpD.montgomery) { s << ZpD.t << " " << ZpD.pi << endl; for (int i=0; i>(istream& s,Zp_Data& ZpD) { s >> ZpD.pr >> ZpD.montgomery; ZpD.init(ZpD.pr, ZpD.montgomery); if (ZpD.montgomery) { s >> ZpD.t >> ZpD.pi; if (ZpD.t>MAX_MOD_SZ) throw max_mod_sz_too_small(ZpD.t); inline_mpn_zero(ZpD.R,MAX_MOD_SZ); inline_mpn_zero(ZpD.R2,MAX_MOD_SZ); inline_mpn_zero(ZpD.R3,MAX_MOD_SZ); inline_mpn_zero(ZpD.prA,MAX_MOD_SZ+1); for (int i=0; i> ZpD.R[i]; } for (int i=0; i> ZpD.R2[i]; } for (int i=0; i> ZpD.R3[i]; } for (int i=0; i> ZpD.prA[i]; } } return s; } void Zp_Data::pack(octetStream& o) const { pr.pack(o); o.store((int)montgomery); } void Zp_Data::unpack(octetStream& o) { pr.unpack(o); int m; o.get(m); montgomery = m; if (pr != 0) init(pr, m); } bool Zp_Data::operator!=(const Zp_Data& other) const { if (pr != other.pr or montgomery != other.montgomery) return true; else return false; } bool Zp_Data::operator==(const Zp_Data& other) const { return not (*this != other); }