Files
MP-SPDZ/Processor/BaseMachine.h
2025-12-24 13:47:42 +11:00

264 lines
6.7 KiB
C++

/*
* BaseMachine.h
*
*/
#ifndef PROCESSOR_BASEMACHINE_H_
#define PROCESSOR_BASEMACHINE_H_
#include "Tools/time-func.h"
#include "Tools/TimerWithComm.h"
#include "OT/OTTripleSetup.h"
#include "ThreadJob.h"
#include "ThreadQueues.h"
#include "Program.h"
#include "OnlineOptions.h"
#include <map>
#include <fstream>
using namespace std;
// Declared here, defined elsewhere. Presumably writes command-line usage
// text for the program `name` to `o`; the role of `capacity` is not visible
// in this header (TODO confirm against the definition).
void print_usage(ostream& o, const char* name, size_t capacity);
/**
 * Common base class for machine implementations. The virtual hooks below
 * (load_program, reqbl, active, get_N, gap_warning) are presumably
 * overridden by protocol-specific subclasses — verify against callers.
 *
 * Holds strings that appear to describe the compiled schedule, the loaded
 * bytecode programs, integer-keyed timers, communication statistics, and
 * static helpers that derive preprocessing batch sizes from the usage
 * recorded for the loaded program.
 */
class BaseMachine
{
friend class Program;
template<class sint, class sgf2n> friend class thread_info;
protected:
// Process-wide instance; has_singleton() tests it against null.
static BaseMachine* singleton;
// Per-thread OT setup from which fresh_ot_setup() draws.
static thread_local OnDemandOTTripleSetup ot_setup;
// Program associated with the calling thread; may be null (batch_size()
// only prints its name in debug output when it is set).
static thread_local const Program* program;
// Timers keyed by an integer id; presumably driven by start(n)/stop(n).
std::map<int,TimerWithComm> timer;
// Descriptive strings, presumably filled from the schedule by
// load_schedule() — TODO confirm which fields the schedule provides.
string compiler;
string domain;
string relevant_opts;
string security;
string gf2n;
string expected_communication;
// Communication outside the regular threads; presumably accumulated by
// add_one_off().
NamedCommStats one_off_comm;
// Loads bytecode for `threadname` from `filename`; the meaning of the
// returned size is not visible in this header.
virtual size_t load_program(const string& threadname,
const string& filename);
static BaseMachine get_basics(string progname);
// Preprocessing usage recorded for the program; the batch-size helpers
// read its `files`, `inputs` and `total_edabits()`.
static DataPositions get_offline_data_used();
public:
// Index of the calling thread (one value per thread).
static thread_local int thread_num;
string progname;
int nthreads;
bool multithread;
ThreadQueues queues;
// Bytecode file names and the programs loaded from them.
vector<string> bc_filenames;
vector<Program> progs;
bool nan_warning;
int mini_warning;
// Access to the process-wide instance (presumably requires
// has_singleton() to be true).
static BaseMachine& s();
static bool has_singleton() { return singleton != 0; }
// Presumably whether the calling thread has a program; the batch-size
// helpers consult recorded usage only when this is true.
static bool has_program();
static string memory_filename(const string& type_short, int my_number);
// Parameters recorded in the schedule of program `progname`.
static string get_domain(string progname);
static int ring_size_from_schedule(string progname);
static int prime_length_from_schedule(string progname);
static int gf2n_length_from_schedule(string progname);
static bigint prime_from_schedule(string progname);
static int security_from_schedule(string progname);
// Batch size for preprocessing of `type` with share type T: a positive
// buffer_size wins, then a positive fallback, then the global option
// scaled by max(factor, T::default_length), tailored to recorded usage.
template<class T>
static int batch_size(Dtype type, int buffer_size = 0, int fallback = 0,
int factor = 0);
// Batch size for inputs of share type T from `player`.
template<class T>
static int input_batch_size(int player, int buffer_size = 0);
// Batch size for edaBits of length n_bits with share type T.
template<class T>
static int edabit_batch_size(int n_bits, int buffer_size = 0);
static int edabit_bucket_size(int n_bits);
static int triple_bucket_size(DataFieldType type);
static int bucket_size(size_t usage);
static int matrix_batch_size(int n_rows, int n_inner, int n_cols);
static int matrix_requirement(int n_rows, int n_inner, int n_cols);
static bool allow_mulm();
static void add_one_off(const NamedCommStats& comm);
BaseMachine();
virtual ~BaseMachine() {}
void load_schedule(const string& progname, bool load_bytecode = true);
void print_compiler();
void time();
void start(int n);
void stop(int n);
void print_timers();
// Hooks that do nothing unless overridden.
virtual void reqbl(int) {}
virtual void active(int) {}
// Fresh OT setup drawn from the calling thread's ot_setup using P.
static OTTripleSetup fresh_ot_setup(Player& P);
NamedCommStats total_comm();
void set_thread_comm(const NamedCommStats& stats);
void print_global_comm(Player& P, const NamedCommStats& stats);
void print_comm(Player& P, const NamedCommStats& stats);
// Throw not_implemented unless a subclass overrides them.
virtual const Names& get_N() { throw not_implemented(); }
virtual void gap_warning(int) { throw not_implemented(); }
};
// Returns a new OTTripleSetup obtained from the calling thread's ot_setup
// (a thread_local OnDemandOTTripleSetup) using player connection P.
inline OTTripleSetup BaseMachine::fresh_ot_setup(Player& P)
{
return ot_setup.get_fresh(P);
}
/**
 * Chooses how many items of preprocessing type `type` to produce per batch
 * for share type T.
 *
 * The upper bound is, by precedence: a positive `buffer_size`, a positive
 * `fallback`, or the global batch size option multiplied by the larger of
 * `factor` and T::default_length. Without a positive `buffer_size` and with
 * has_program() true, the program's recorded usage of `type` becomes the
 * demand (DATA_BIT and DATA_DABIT usage is summed when a daBit request meets
 * T::LivePrep::bits_from_dabits() or a bit request meets
 * T::LivePrep::dabits_from_bits()). A demand above the bound is split into
 * k parts for the smallest k in 1..10 with ceil(demand / k) <= bound, and
 * is otherwise clipped to the bound.
 *
 * @return the batch size; asserted to be positive
 */
template<class T>
int BaseMachine::batch_size(Dtype type, int buffer_size, int fallback,
        int factor)
{
    if (OnlineOptions::singleton.has_option("debug_batch_size"))
        fprintf(stderr, "batch_size buffer_size=%d fallback=%d\n", buffer_size,
                fallback);

    // Upper bound on the batch; the scaled default is only evaluated when
    // neither an explicit size nor a fallback was supplied.
    const int n_opts = (buffer_size > 0) ? buffer_size :
            ((fallback > 0) ? fallback :
                    OnlineOptions::singleton.batch_size
                            * max(factor, T::default_length));

    int demand = 0;
    if (buffer_size <= 0 and has_program())
    {
        auto files = get_offline_data_used().files;
        auto usage = files[T::clear::field_type()];
        const bool pool_bits_and_dabits =
                (type == DATA_DABIT and T::LivePrep::bits_from_dabits())
                        or (type == DATA_BIT and T::LivePrep::dabits_from_bits());
        if (pool_bits_and_dabits)
            demand = usage[DATA_BIT] + usage[DATA_DABIT];
        else
            demand = usage[type];
    }
    else if (type != DATA_DABIT)
    {
        // The explicit size becomes the demand and is cleared (the debug
        // output below shows the cleared value). A positive size equals
        // n_opts here, so the result is still that size; a non-positive
        // one leaves demand <= 0 and the bound is used.
        demand = buffer_size;
        buffer_size = 0;
    }

    int res = n_opts;
    if (demand > 0)
    {
        res = min(demand, n_opts);
        if (demand > n_opts)
        {
            // Prefer splitting the demand evenly into as few parts as
            // possible (at most 10) over plain truncation.
            for (int parts = 1; parts <= 10; parts++)
            {
                const int share = DIV_CEIL(demand, parts);
                if (share <= n_opts)
                {
                    res = share;
#ifdef DEBUG_BATCH_SIZE
                    cerr << "found fraction " << share << endl;
#endif
                    break;
                }
            }
        }
    }

    if (OnlineOptions::singleton.has_option("debug_batch_size"))
    {
        cerr << DataPositions::dtype_names[type] << " " << T::type_string()
                << " res=" << res << " n=" << demand << " n_opts=" << n_opts
                << " buffer_size=" << buffer_size << " bits/dabits="
                << T::LivePrep::bits_from_dabits() << "/"
                << T::LivePrep::dabits_from_bits() << " has_program="
                << has_program();
        if (program)
            cerr << " program=" << program->get_name();
        cerr << endl;
    }
    assert(res > 0);
    return res;
}
/**
 * Batch size for input preprocessing of share type T from `player`.
 *
 * @param player index into the per-player input usage recorded for the
 *        program
 * @param buffer_size explicit batch size; honoured only when positive,
 *        consistent with batch_size() and edabit_batch_size()
 * @return buffer_size if positive; else, when has_program() is true and the
 *         recorded number of inputs from `player` for T's field type is
 *         positive, that number; else the global batch size option
 */
template<class T>
int BaseMachine::input_batch_size(int player, int buffer_size)
{
    // Only a positive value is an explicit request. Zero (the default) or a
    // negative value falls through, instead of a negative value being
    // returned as the batch size.
    if (buffer_size > 0)
        return buffer_size;

    if (has_program())
    {
        auto recorded =
                get_offline_data_used().inputs[player][T::clear::field_type()];
        if (recorded > 0)
            return recorded;
    }

    return OnlineOptions::singleton.batch_size;
}
/**
 * Batch size for edaBits of length `n_bits` with share type T.
 *
 * @param n_bits edaBit length whose recorded usage is consulted
 * @param buffer_size explicit batch size; used as is when positive
 * @return buffer_size if positive; otherwise the global batch size option,
 *         lowered to the recorded edaBit usage for `n_bits` when
 *         has_program() is true and that usage is positive and smaller.
 *         Asserted to be positive.
 */
template<class T>
int BaseMachine::edabit_batch_size(int n_bits, int buffer_size)
{
    const int limit = (buffer_size > 0) ? buffer_size
            : OnlineOptions::singleton.batch_size;

    // Recorded edaBit demand of the program (0 when there is none).
    int used = 0;
    if (has_program())
        used = get_offline_data_used().total_edabits(n_bits);

    // Only an implicit request is tailored to the recorded demand.
    const bool tailor = buffer_size <= 0 and used > 0;
    const int res = tailor ? min(used, limit) : limit;

#ifdef DEBUG_BATCH_SIZE
    cerr << "edaBits " << T::type_string() << " (" << n_bits
            << ") res=" << res << " n=" << used << " n_opts=" << limit
            << " buffer_size=" << buffer_size << endl;
#endif
    assert(res > 0);
    return res;
}
#endif /* PROCESSOR_BASEMACHINE_H_ */