mirror of
https://github.com/data61/MP-SPDZ.git
synced 2026-01-10 14:08:09 -05:00
205 lines
3.7 KiB
C++
205 lines
3.7 KiB
C++
/*
|
|
* Z2k.cpp
|
|
*
|
|
*/
|
|
|
|
#ifndef MATH_Z2K_HPP_
|
|
#define MATH_Z2K_HPP_
|
|
|
|
#include <Math/Z2k.h>
|
|
#include "Math/Integer.h"
|
|
|
|
#include <math.h>
|
|
|
|
template<int K>
|
|
const int Z2<K>::N_BITS;
|
|
template<int K>
|
|
const int Z2<K>::N_BYTES;
|
|
template<int K>
|
|
const false_type Z2<K>::invertible;
|
|
|
|
template<int K>
|
|
void Z2<K>::reqbl(int n)
|
|
{
|
|
if (n < 0 && N_BITS != -(int)n)
|
|
{
|
|
throw Processor_Error(
|
|
"Program compiled for rings of length " + to_string(-n)
|
|
+ " but VM supports only "
|
|
+ to_string(N_BITS));
|
|
}
|
|
else if (n > 0)
|
|
{
|
|
throw Processor_Error("Program compiled for fields not rings, "
|
|
"run compile.py with '-R " + to_string(K) + "'");
|
|
}
|
|
}
|
|
|
|
template<int K>
|
|
bool Z2<K>::allows(Dtype dtype)
|
|
{
|
|
return Integer::allows(dtype);
|
|
}
|
|
|
|
template<int K>
|
|
void Z2<K>::specification(octetStream& os)
|
|
{
|
|
os.store(K);
|
|
}
|
|
|
|
template<int K>
|
|
Z2<K>::Z2(const bigint& x) : Z2()
|
|
{
|
|
auto mp = x.get_mpz_t();
|
|
memcpy(a, mp->_mp_d, min((size_t)N_BYTES, sizeof(mp_limb_t) * abs(mp->_mp_size)));
|
|
if (mp->_mp_size < 0)
|
|
*this = Z2<K>() - *this;
|
|
normalize_byte();
|
|
}
|
|
|
|
template<int K>
|
|
template<class T>
|
|
Z2<K>::Z2(const IntBase<T>& x) :
|
|
Z2((mp_limb_t)x.get())
|
|
{
|
|
}
|
|
|
|
template<int K>
|
|
bool Z2<K>::get_bit(int i) const
|
|
{
|
|
if (i < N_BITS)
|
|
return 1 & (a[i / N_LIMB_BITS] >> (i % N_LIMB_BITS));
|
|
else
|
|
return false;
|
|
}
|
|
|
|
template<int K>
|
|
int Z2<K>::bit_length() const
|
|
{
|
|
if (is_zero())
|
|
return 1;
|
|
size_t max_limb = 0;
|
|
for (int i = 1; i < N_WORDS; i++)
|
|
if (a[i] != 0)
|
|
max_limb = i;
|
|
return log2(mp_limb_t(a[max_limb])) + 1 + 64 * max_limb;
|
|
}
|
|
|
|
template<int K>
|
|
Z2<K> Z2<K>::operator&(const Z2<K>& other) const
|
|
{
|
|
Z2<K> res;
|
|
mpn_and_n(res.a, a, other.a, N_WORDS);
|
|
return res;
|
|
}
|
|
|
|
template<int K>
|
|
Z2<K> Z2<K>::operator^(const Z2<K>& other) const
|
|
{
|
|
Z2<K> res;
|
|
mpn_xor_n(res.a, a, other.a, N_WORDS);
|
|
return res;
|
|
}
|
|
|
|
template<int K>
|
|
Z2<K> Z2<K>::operator|(const Z2<K>& other) const
|
|
{
|
|
Z2<K> res;
|
|
mpn_ior_n(res.a, a, other.a, N_WORDS);
|
|
return res;
|
|
}
|
|
|
|
template<int K>
|
|
bool Z2<K>::operator==(const Z2<K>& other) const
|
|
{
|
|
#ifdef DEBUG_MPN
|
|
for (int i = 0; i < N_WORDS; i++)
|
|
cout << "cmp " << hex << a[i] << " " << other.a[i] << endl;
|
|
#endif
|
|
return mpn_cmp(a, other.a, N_WORDS) == 0;
|
|
}
|
|
|
|
template<int K>
|
|
Z2<K> Z2<K>::invert() const
|
|
{
|
|
if (get_bit(0) != 1)
|
|
throw division_by_zero();
|
|
|
|
Z2<K> res = 1;
|
|
for (int i = 0; i < K; i++)
|
|
{
|
|
res += Z2<K>((Z2<K>(1) - Z2<K>::Mul(*this, res)).get_bit(i)) << i;
|
|
}
|
|
return res;
|
|
}
|
|
|
|
template<int K>
|
|
Z2<K> Z2<K>::sqrRoot()
|
|
{
|
|
assert(a[0] % 8 == 1);
|
|
Z2<K> res = 1;
|
|
for (int i = 0; i < K - 1; i++)
|
|
{
|
|
res += Z2<K>((*this - Z2<K>::Mul(res, res)).get_bit(i + 1)) << i;
|
|
#ifdef DEBUG_SQR
|
|
cout << "sqr " << dec << i << " " << hex << res << " " << res * res << " " << (*this - res * res) << endl;
|
|
#endif
|
|
}
|
|
return res;
|
|
}
|
|
|
|
template<int K>
|
|
void Z2<K>::input(istream& s, bool human)
|
|
{
|
|
if (human)
|
|
{
|
|
s >> bigint::tmp;
|
|
*this = bigint::tmp;
|
|
}
|
|
else
|
|
s.read((char*)a, N_BYTES);
|
|
}
|
|
|
|
template<int K>
|
|
void Z2<K>::output(ostream& s, bool human) const
|
|
{
|
|
if (human)
|
|
{
|
|
bigint::tmp = *this;
|
|
s << bigint::tmp;
|
|
}
|
|
else
|
|
s.write((char*)a, N_BYTES);
|
|
}
|
|
|
|
template<int K>
|
|
SignedZ2<K>::SignedZ2(const Integer& other)
|
|
{
|
|
if (K == 64)
|
|
this->a[0] = other.get();
|
|
else
|
|
*this = SignedZ2<64>(other);
|
|
}
|
|
|
|
template <int K>
|
|
ostream& operator<<(ostream& o, const Z2<K>& x)
|
|
{
|
|
x.output(o, true);
|
|
return o;
|
|
}
|
|
|
|
template<int K>
|
|
istream& operator>>(istream& i, SignedZ2<K>& x)
|
|
{
|
|
auto& tmp = bigint::tmp;
|
|
i >> tmp;
|
|
if (tmp.numBits() > K + 1)
|
|
throw runtime_error(
|
|
tmp.get_str() + " out of range for signed " + to_string(K)
|
|
+ "-bit numbers");
|
|
x = tmp;
|
|
return i;
|
|
}
|
|
|
|
#endif
|