Files
MP-SPDZ/Protocols/FakeProtocol.h
2023-12-14 12:17:54 +11:00

369 lines
10 KiB
C++

/*
* FakeProtocol.h
*
*/
#ifndef PROTOCOLS_FAKEPROTOCOL_H_
#define PROTOCOLS_FAKEPROTOCOL_H_
#include "Replicated.h"
#include "SecureShuffle.h"
#include "Math/Z2k.h"
#include "Processor/Instruction.h"
#include "Processor/TruncPrTuple.h"
#include <cmath>
template<class T>
class FakeShuffle
{
public:
typedef ShuffleStore<int> store_type;
FakeShuffle(SubProcessor<T>&)
{
}
FakeShuffle(vector<T>& a, size_t n, int unit_size, size_t output_base,
size_t input_base, SubProcessor<T>&)
{
apply(a, n, unit_size, output_base, input_base, 0, 0);
}
size_t generate(size_t, store_type& store)
{
return store.add();
}
void apply(vector<T>& a, size_t n, int unit_size, size_t output_base,
size_t input_base, int, bool)
{
auto source = a.begin() + input_base;
auto dest = a.begin() + output_base;
for (size_t i = 0; i < n; i++)
// just copy
*dest++ = *source++;
if (n > 1)
{
// swap first two to pass check
for (int i = 0; i < unit_size; i++)
swap(a[output_base + i], a[output_base + i + unit_size]);
}
}
void inverse_permutation(vector<T>&, size_t, size_t, size_t)
{
}
};
template<class T>
class FakeProtocol : public ProtocolBase<T>
{
PointerVector<T> results;
SeededPRNG G;
T dot_prod;
T trunc_max;
int fails;
vector<size_t> trunc_stats;
map<string, size_t> cisc_stats;
map<int, size_t> ltz_stats;
public:
typedef FakeShuffle<T> Shuffler;
Player& P;
FakeProtocol(Player& P) :
fails(0), trunc_stats(T::MAX_N_BITS + 1), P(P)
{
}
#ifdef VERBOSE
~FakeProtocol()
{
output_trunc_max<0>(T::invertible);
double expected = 0;
for (int i = 0; i <= T::MAX_N_BITS; i++)
{
if (trunc_stats[i] != 0)
cerr << i << ": " << trunc_stats[i] << endl;
expected += trunc_stats[i] * exp2(i - T::MAX_N_BITS);
}
if (expected != 0)
cerr << "Expected truncation failures: " << expected << endl;
for (auto& x : cisc_stats)
{
cerr << x.second << " " << x.first << endl;
}
for (auto& x : ltz_stats)
cerr << "LTZ " << x.first << ": " << x.second << endl;
}
template<int>
void output_trunc_max(false_type)
{
if (trunc_max != T())
cerr << "Maximum bit length in truncation: "
<< (bigint(typename T::clear(trunc_max)).numBits() + 1)
<< " (" << trunc_max << ")" << endl;
}
template<int>
void output_trunc_max(true_type)
{
}
#endif
FakeProtocol branch()
{
return P;
}
void init_mul()
{
results.clear();
}
void prepare_mul(const T& x, const T& y, int = -1)
{
results.push_back(x * y);
}
void exchange()
{
}
T finalize_mul(int = -1)
{
return results.next();
}
void init_dotprod()
{
init_mul();
dot_prod = {};
}
void prepare_dotprod(const T& x, const T& y)
{
dot_prod += x * y;
}
void next_dotprod()
{
results.push_back(dot_prod);
dot_prod = 0;
}
T finalize_dotprod(int)
{
return finalize_mul();
}
void randoms(T& res, int n_bits)
{
res.randomize_part(G, n_bits);
}
int get_n_relevant_players()
{
return 1;
}
void trunc_pr(const vector<int>& regs, int size, SubProcessor<T>& proc)
{
trunc_pr<0>(regs, size, proc, T::characteristic_two);
}
template<int>
void trunc_pr(const vector<int>&, int, SubProcessor<T>&, true_type)
{
throw not_implemented();
}
template<int>
void trunc_pr(const vector<int>& regs, int size, SubProcessor<T>& proc, false_type)
{
this->trunc_rounds++;
this->trunc_pr_counter += size * regs.size() / 4;
for (size_t i = 0; i < regs.size(); i += 4)
for (int l = 0; l < size; l++)
{
auto& res = proc.get_S_ref(regs[i] + l);
auto& source = proc.get_S_ref(regs[i + 1] + l);
T tmp = source;
tmp = tmp < T() ? (T() - tmp) : tmp;
trunc_max = max(trunc_max, tmp);
#ifdef TRUNC_PR_EMULATION_STATS
trunc_stats.at(tmp == T() ? 0 : tmp.bit_length())++;
#endif
#ifdef CHECK_BOUNDS_IN_TRUNC_PR_EMULATION
auto test = (source >> (regs[i + 2]));
if (test != 0 and test != T(-1) >> regs[i + 2])
{
cerr << typename T::clear(source) << " has more than "
<< regs[i + 2]
<< " bits in " << regs[i + 3]
<< "-bit truncation (test value "
<< typename T::clear(test) << ")" << endl;
fails++;
if (fails > 1000)
throw runtime_error("trunc_pr overflow");
}
#endif
int n_shift = regs[i + 3];
#ifdef ROUND_NEAREST_IN_EMULATION
res = source >> n_shift;
if (n_shift > 0)
{
bool overflow = T(source >> (n_shift - 1)).get_bit(0);
res += overflow;
}
#else
if (TruncPrTupleWithGap<typename T::clear>(regs, i).big_gap())
{
T r;
r.randomize(G);
if (source.negative())
res = -T(((-source + r) >> n_shift) - (r >> n_shift));
else
res = ((source + r) >> n_shift) - (r >> n_shift);
}
else
{
T r;
r.randomize_part(G, n_shift);
res = (source + r) >> n_shift;
}
#endif
}
}
void cisc(SubProcessor<T>& processor, const Instruction& instruction)
{
cisc(processor, instruction, T::characteristic_two);
}
template<int = 0>
void cisc(SubProcessor<T>&, const Instruction&, true_type)
{
throw not_implemented();
}
template<int = 0>
void cisc(SubProcessor<T>& processor, const Instruction& instruction, false_type)
{
int r0 = instruction.get_r(0);
string tag((char*)&r0, 4);
cisc_stats[tag.c_str()]++;
auto& args = instruction.get_start();
if (tag == string("LTZ\0", 4))
{
for (size_t i = 0; i < args.size(); i += args[i])
{
ltz_stats[args[i + 4]] += args[i + 1];
assert(i + args[i] <= args.size());
assert(args[i] == 6);
for (int j = 0; j < args[i + 1]; j++)
{
auto& res = processor.get_S()[args[i + 2] + j];
res = T(processor.get_S()[args[i + 3] + j]).get_bit(
args[i + 4] - 1);
}
}
}
else if (tag == string("EQZ\0", 4))
{
for (size_t i = 0; i < args.size(); i += args[i])
{
assert(i + args[i] <= args.size());
assert(args[i] == 6);
for (int j = 0; j < args[i + 1]; j++)
{
auto& res = processor.get_S()[args[i + 2] + j];
res = processor.get_S()[args[i + 3] + j] == 0;
}
}
}
else if (tag == "Trun")
{
for (size_t i = 0; i < args.size(); i += args[i])
{
assert(i + args[i] <= args.size());
assert(args[i] == 8);
int k = args[i + 4];
int m = args[i + 5];
int s = args[i + 7];
assert((s == 0) or (s == 1));
for (int j = 0; j < args[i + 1]; j++)
{
auto& res = processor.get_S()[args[i + 2] + j];
res = ((T(processor.get_S()[args[i + 3] + j])
+ (T(s) << (k - 1))) >> m) - (T(s) << (k - m - 1));
}
}
}
else if (tag == "FPDi")
{
for (size_t i = 0; i < args.size(); i += args[i])
{
assert(i + args[i] <= args.size());
int f = args.at(i + 6);
for (int j = 0; j < args[i + 1]; j++)
{
auto& res = processor.get_S()[args[i + 2] + j];
mpf_class a[2];
for (int k = 0; k < 2; k++)
a[k] = bigint(typename T::clear(
processor.get_S()[args[i + 3 + k] + j]));
if (a[1] != 0)
res = bigint(a[0] / a[1] * exp2(f));
else
res = 0;
}
}
}
else if (tag == "exp2")
{
for (size_t i = 0; i < args.size(); i += args[i])
{
assert(i + args[i] <= args.size());
int f = args.at(i + 5);
for (int j = 0; j < args[i + 1]; j++)
{
auto& res = processor.get_S()[args[i + 2] + j];
auto a = bigint(typename T::clear(
processor.get_S()[args[i + 3] + j]));
res = bigint(round(exp2(mpf_class(a).get_d() / exp2(f) + f)));
}
}
}
else if (tag == "log2")
{
for (size_t i = 0; i < args.size(); i += args[i])
{
assert(i + args[i] <= args.size());
int f = args.at(i + 5);
for (int j = 0; j < args[i + 1]; j++)
{
auto& res = processor.get_S()[args[i + 2] + j];
auto a = bigint(typename T::clear(
processor.get_S()[args[i + 3] + j]));
res = bigint(round((log2(mpf_class(a).get_d()) - f) * exp2(f)));
}
}
}
else
throw runtime_error("unknown CISC instruction: " + tag);
}
};
#endif /* PROTOCOLS_FAKEPROTOCOL_H_ */