/* * BaseMachine.cpp * */ #include "BaseMachine.h" #include "OnlineOptions.h" #include "Math/Setup.h" #include "Tools/Bundle.h" #include "Instruction.hpp" #include "Protocols/ShuffleSacrifice.hpp" #include #include using namespace std; BaseMachine* BaseMachine::singleton = 0; thread_local int BaseMachine::thread_num; thread_local OnDemandOTTripleSetup BaseMachine::ot_setup; thread_local const Program* BaseMachine::program = 0; void print_usage(ostream& o, const char* name, size_t capacity) { if (capacity) o << name << "=" << capacity << " "; } BaseMachine& BaseMachine::s() { if (singleton) return *singleton; else throw runtime_error("no BaseMachine singleton"); } bool BaseMachine::has_program() { return has_singleton() and not s().progs.empty(); } DataPositions BaseMachine::get_offline_data_used() { if (program) return program->get_offline_data_used(); else return s().progs[0].get_offline_data_used(); } int BaseMachine::edabit_bucket_size(int n_bits) { size_t usage = 0; if (has_program()) usage = get_offline_data_used().total_edabits(n_bits); return bucket_size(usage); } int BaseMachine::triple_bucket_size(DataFieldType type) { size_t usage = 0; if (has_program()) usage = get_offline_data_used().files[type][DATA_TRIPLE]; return bucket_size(usage); } int BaseMachine::bucket_size(size_t usage) { int res = OnlineOptions::singleton.bucket_size; if (usage) { for (int B = res; B <= 5; B++) if (ShuffleSacrifice(B).minimum_n_outputs() < usage * .9) break; else res = B; } return res; } int BaseMachine::matrix_batch_size(int n_rows, int n_inner, int n_cols) { int limit = max(1., 1e6 / (max(n_rows * n_inner, n_inner * n_cols))); unsigned res = min(limit, OnlineOptions::singleton.batch_size); if (has_program()) res = min(res, (unsigned) matrix_requirement(n_rows, n_inner, n_cols)); return res; } int BaseMachine::matrix_requirement(int n_rows, int n_inner, int n_cols) { if (has_program()) { auto res = get_offline_data_used().matmuls[ {n_rows, n_inner, n_cols}]; if (res) return res; else return -1; } else return -1; } BaseMachine::BaseMachine() : nthreads(0), multithread(false), nan_warning(0) { if (sodium_init() == -1) throw runtime_error("couldn't initialize libsodium"); if (not singleton) singleton = this; } void BaseMachine::load_schedule(const string& progname, bool load_bytecode) { this->progname = progname; string fname = "Programs/Schedules/" + progname + ".sch"; #ifdef DEBUG_FILES cerr << "Opening file " << fname << endl; #endif ifstream inpf; inpf.open(fname); if (inpf.fail()) { throw file_error("Missing '" + fname + "'. Did you compile '" + progname + "'?"); } int nprogs; inpf >> nthreads; inpf >> nprogs; if (inpf.fail()) throw file_error("Error reading " + fname); #ifdef DEBUG_FILES cerr << "Number of threads I will run in parallel = " << nthreads << endl; cerr << "Number of program sequences I need to load = " << nprogs << endl; #endif bc_filenames.clear(); // Load in the programs string threadname; for (int i=0; i> threadname; size_t split = threadname.find_last_of(":"); long expected = -1; if (split != string::npos) { expected = atoi(threadname.substr(split + 1).c_str()); threadname = threadname.substr(0, split); } string filename = "Programs/Bytecode/" + threadname + ".bc"; bc_filenames.push_back(filename); if (load_bytecode) { #ifdef DEBUG_FILES cerr << "Loading program " << i << " from " << filename << endl; #endif long size = load_program(threadname, filename); if (expected >= 0 and expected != size) { stringstream os; os << "broken bytecode file, found " << size << " instructions, expected " << expected; throw runtime_error(os.str()); } } } for (auto i : {1, 0, 0}) { int n; inpf >> n; if (n != i) throw runtime_error("old schedule format not supported"); } inpf.get(); getline(inpf, compiler); getline(inpf, domain); getline(inpf, relevant_opts); getline(inpf, security); getline(inpf, gf2n); inpf.close(); } void BaseMachine::print_compiler() { if (compiler.size() != 0 and OnlineOptions::singleton.verbose) cerr << "Compiler: " << compiler << endl; } size_t BaseMachine::load_program(const string& threadname, const string& filename) { (void)threadname; (void)filename; throw not_implemented(); } void BaseMachine::time() { cout << "Elapsed time: " << timer[0].elapsed() << endl; } void BaseMachine::start(int n) { cout << "Starting timer " << n << " at " << timer[n].elapsed() << " (" << timer[n] << ")" << " after " << timer[n].idle() << endl; timer[n].start(total_comm()); } void BaseMachine::stop(int n) { timer[n].stop(total_comm()); cout << "Stopped timer " << n << " at " << timer[n].elapsed() << " (" << timer[n] << ")" << endl; } void BaseMachine::print_timers() { cerr << "The following benchmarks are "; if (OnlineOptions::singleton.live_prep) cerr << "in"; else cerr << "ex"; cerr << "cluding preprocessing (offline phase)." << endl; cerr << "Time = " << timer[0].elapsed() << " seconds " << endl; timer.erase(0); for (auto it = timer.begin(); it != timer.end(); it++) cerr << "Time" << it->first << " = " << it->second.elapsed() << " seconds (" << it->second << ")" << endl; } string BaseMachine::memory_filename(const string& type_short, int my_number) { return PREP_DIR "Memory-" + type_short + "-P" + to_string(my_number); } string BaseMachine::get_domain(string progname) { return get_basics(progname).domain; } BaseMachine BaseMachine::get_basics(string progname) { if (singleton and s().progname == progname) return s(); auto backup = singleton; BaseMachine machine; singleton = backup; machine.load_schedule(progname, false); return machine; } int BaseMachine::ring_size_from_schedule(string progname) { string domain = get_domain(progname); if (domain.substr(0, 2).compare("R:") == 0) { return stoi(domain.substr(2)); } else return 0; } int BaseMachine::prime_length_from_schedule(string progname) { string domain = get_domain(progname); if (domain.substr(0, 4).compare("lgp:") == 0) return stoi(domain.substr(4)); else return 0; } int BaseMachine::gf2n_length_from_schedule(string progname) { string domain = get_basics(progname).gf2n; if (domain.substr(0, 4).compare("lg2:") == 0) return stoi(domain.substr(4)); else return 0; } bigint BaseMachine::prime_from_schedule(string progname) { string domain = get_domain(progname); if (domain.substr(0, 2).compare("p:") == 0) return bigint(domain.substr(2)); else return 0; } int BaseMachine::security_from_schedule(string progname) { string sec = get_basics(progname).security; if (sec.substr(0, 4).compare("sec:") == 0) return stoi(sec.substr(4)); else return 0; } NamedCommStats BaseMachine::total_comm() { return queues.total_comm(); } void BaseMachine::set_thread_comm(const NamedCommStats& stats) { auto queue = queues.at(BaseMachine::thread_num); assert(queue); queue->set_comm_stats(stats); } void BaseMachine::print_global_comm(Player& P, const NamedCommStats& stats) { Bundle bundle(P); bundle.mine.store(stats.sent); P.Broadcast_Receive_no_stats(bundle); size_t global = 0; for (auto& os : bundle) global += os.get_int(8); cerr << "Global data sent = " << global / 1e6 << " MB (all parties)" << endl; } void BaseMachine::print_comm(Player& P, const NamedCommStats& comm_stats) { size_t rounds = 0; for (auto& x : comm_stats) rounds += x.second.rounds; cerr << "Data sent = " << comm_stats.sent / 1e6 << " MB in ~" << rounds << " rounds (party " << P.my_num() << " only"; if (multithread) cerr << "; rounds counted double due to multi-threading"; if (not OnlineOptions::singleton.verbose) cerr << "; use '-v' for more details"; cerr << ")" << endl; print_global_comm(P, comm_stats); }