#ifndef _Data_Files
#define _Data_Files

/* This class holds the Online data files all in one place
 * so the streams are easy to pass around and access
 *
 * Contents of this header, in order:
 *   - DataTag:        small fixed-size key built from three ints.
 *   - DataPositions:  usage counters (per field type / data type, per input
 *                     player, per tag, plus the edabits and matmuls maps).
 *   - Preprocessing:  abstract interface for fetching preprocessed data,
 *                     with counting wrappers around the *_no_count hooks.
 *   - Sub_Data_Files: Preprocessing implementation reading from buffers.
 *   - Data_Files:     pair of Preprocessing references (DataFp, DataF2)
 *                     together with usage bookkeeping.
 *
 * NOTE(review): in this view the template parameter lists and the names of
 * the two system headers below appear to have been stripped. The code tokens
 * are left exactly as found; restore them from the authoritative copy.
 */

#include "Math/field_types.h"
#include "Tools/Buffer.h"
#include "Processor/InputTuple.h"
#include "Tools/Lock.h"
#include "Networking/Player.h"
#include "Protocols/edabit.h"
#include "PrepBase.h"

// NOTE(review): the header names of the next two includes are missing here.
#include
#include
using namespace std;

template class dabit;

/* Key type used to index the tag-based ("extended") counters and buffers.
 * Holds four ints: three taken from the caller plus a zero terminator word.
 */
class DataTag
{
  int t[4];

public:
  // assume that tag is three integers
  // The words are copied as bytes with strncpy, i.e. copying stops at the
  // first zero byte and the rest of the 3 * sizeof(int) bytes is zero-filled.
  // NOTE(review): presumably the tag carries packed character data, so this
  // normalisation is intended - confirm against the callers.
  DataTag(const int* tag)
  {
    strncpy((char*)t, (char*)tag, 3 * sizeof(int));
    // t[3] == 0 guarantees get_string() finds a terminating zero byte.
    t[3] = 0;
  }

  // Returns the stored bytes interpreted as a zero-terminated C string.
  string get_string() const
  {
    return string((char*)t);
  }

  // Lexicographic comparison over the first three ints
  // (t[3] is always zero and is therefore skipped).
  bool operator<(const DataTag& other) const
  {
    for (int i = 0; i < 3; i++)
      if (t[i] != other.t[i])
        return t[i] < other.t[i];
    return false;
  }
};

/* Collection of counters describing how much data of each kind is involved.
 * Only declarations are visible here; the definitions live elsewhere.
 */
class DataPositions
{
  // Helper declared for internal use; judging by its parameters it is
  // presumably used by print_cost() - confirm in the implementation file.
  void process_line(long long items_used, const char* name, ifstream& file,
      bool print_verbose, double& total_cost, bool& reading_field,
      string suffix = "") const;

public:
  // Name / size tables, defined elsewhere.
  static const char* dtype_names[];
  static const char* field_names[N_DATA_FIELD_TYPE];
  static const int tuple_size[N_DTYPE];

  // Indexed as files[field type][data type] (see Preprocessing::count).
  array, N_DATA_FIELD_TYPE> files;
  // Indexed as inputs[player][field type] (see Preprocessing::count_input).
  vector< array > inputs;
  // One map per field type; presumably keyed by DataTag (see count()).
  array, N_DATA_FIELD_TYPE> extended;
  map, long long> edabits;
  map, long long> matmuls;

  DataPositions(int num_players = 0);
  // Delegates to the int constructor using P.num_players().
  DataPositions(const Player& P) : DataPositions(P.num_players()) {}
  ~DataPositions();

  void reset();
  void set_num_players(int num_players);
  // Number of players equals the length of the inputs vector.
  // NOTE(review): not const-qualified, so it cannot be called on a
  // const DataPositions.
  int num_players() { return inputs.size(); }

  void count(DataFieldType type, DataTag tag, int n = 1);
  void count_edabit(bool strict, int n_bits);

  void increase(const DataPositions& delta);
  DataPositions& operator-=(const DataPositions& delta);
  DataPositions operator-(const DataPositions& delta) const;

  void print_cost() const;
  bool empty() const;
  bool any_more(const DataPositions& other) const;
};

template class Processor;
template class Data_Files;
template class Machine;
template class SubProcessor;

/* Abstract source of preprocessed data for share type T.
 *
 * Derived classes implement the pure virtual *_no_count hooks. The
 * non-virtual get*() wrappers defined near the end of this header first
 * record the request in `usage` and then forward to the matching hook.
 */
template class Preprocessing : public PrepBase
{
protected:
  // Counter object supplied by the caller; held by reference, not owned.
  DataPositions& usage;

  map, vector>> edabits;
  map, edabitvec> my_edabits;

  // When false, count() and count_input() add zero (do_count multiplies
  // the increment).
  bool do_count;

  // Adds n (or 0 if counting is disabled) to the counter for this field
  // type and data type.
  void count(Dtype dtype, int n = 1)
  { usage.files[T::clear::field_type()][dtype] += do_count * n; }
  // Adds one (or 0 if counting is disabled) to the input counter of player.
  void count_input(int player)
  { usage.inputs[player][T::clear::field_type()] += do_count; }

  // Tag-dispatched helpers behind the virtual get_edabits() below: the
  // last parameter receives T::clear::characteristic_two. The false_type
  // overload is defined elsewhere; the true_type overload always throws.
  template
  void get_edabits(bool strict, size_t size, T* a,
          vector& Sb, const vector& regs, false_type);
  template
  void get_edabits(bool, size_t, T*, vector&,
          const vector&, true_type)
  { throw not_implemented(); }

  T get_random_from_inputs(int nplayers);

public:
  // Factory functions, defined elsewhere.
  template
  static Preprocessing* get_new(Machine& machine, DataPositions& usage,
      SubProcessor* proc);
  static Preprocessing* get_live_prep(SubProcessor* proc,
      DataPositions& usage);

  // Binds the usage counters and enables counting.
  Preprocessing(DataPositions& usage) : usage(usage), do_count(true) {}
  virtual ~Preprocessing() {}

  virtual void set_protocol(typename T::Protocol& protocol) = 0;
  // Default: ignores the argument.
  virtual void set_proc(SubProcessor* proc) { (void) proc; }

  // Defaults below are no-ops; derived classes may override them.
  virtual void seekg(DataPositions& pos) { (void) pos; }
  virtual void prune() {}
  virtual void purge() {}

  // Returns the `sent` member of comm_stats().
  virtual size_t data_sent() { return comm_stats().sent; }
  // Default: value-initialized statistics.
  virtual NamedCommStats comm_stats() { return {}; }

  // Hooks that fetch data without touching the usage counters.
  virtual void get_three_no_count(Dtype dtype, T& a, T& b, T& c) = 0;
  virtual void get_two_no_count(Dtype dtype, T& a, T& b) = 0;
  virtual void get_one_no_count(Dtype dtype, T& a) = 0;
  virtual void get_input_no_count(T& a, typename T::open_type& x, int i) = 0;
  virtual void get_no_count(vector& S, DataTag tag, const vector& regs,
      int vector_size) = 0;

  // Counting wrappers; definitions follow the class definitions below.
  void get(Dtype dtype, T* a);
  void get_three(Dtype dtype, T& a, T& b, T& c);
  void get_two(Dtype dtype, T& a, T& b);
  void get_one(Dtype dtype, T& a);
  void get_input(T& a, typename T::open_type& x, int i);
  void get(vector& S, DataTag tag, const vector& regs, int vector_size);

  virtual array get_triple(int n_bits);
  virtual array get_triple_no_count(int n_bits);
  virtual T get_bit();
  virtual T get_random();
  virtual void get_dabit(T&, typename T::bit_type&);
  // Default: not available, throws runtime_error.
  virtual void get_dabit_no_count(T&, typename T::bit_type&) { throw runtime_error("no daBit"); }
  // Forwards to the tag-dispatched helper selected by
  // T::clear::characteristic_two.
  virtual void get_edabits(bool strict, size_t size, T* a,
          vector& Sb, const vector& regs)
  { get_edabits<0>(strict, size, a, Sb, regs, T::clear::characteristic_two); }
  template
  void get_edabit_no_count(bool, int n_bits, edabit& eb);
  template
  edabitvec get_edabitvec(bool strict, int n_bits);
  // Default: not available, throws runtime_error.
  virtual void buffer_edabits_with_queues(bool, int) { throw runtime_error("no edaBits"); }

  // Default: not available, throws runtime_error.
  virtual void push_triples(const vector>&)
  { throw runtime_error("no pushing"); }

  // Defaults are no-ops.
  virtual void buffer_triples() {}
  virtual void buffer_inverses() {}

  // Default: not available, throws runtime_error.
  virtual Preprocessing& get_part() { throw runtime_error("no part"); }
};

/* Preprocessing implementation that serves requests from buffer objects:
 * one per data type, one per other player's inputs, one for this player's
 * own inputs, plus tag-indexed, dabit and edabit buffers.
 * NOTE(review): the member and static function names (prep_data_dir,
 * get_filename, ...) suggest the buffers are file-backed - confirm in the
 * implementation file.
 */
template class Sub_Data_Files : public Preprocessing
{
  template friend class Sub_Data_Files;

  static int tuple_length(int dtype);

  // One buffer per data type, indexed by Dtype.
  BufferOwner buffers[N_DTYPE];
  // Indexed by player number (see get_input_no_count / input_eof).
  vector> input_buffers;
  // Buffer for this player's own inputs; read as RefInputTuple.
  BufferOwner, RefInputTuple> my_input_buffers;
  map > extended;
  BufferOwner, dabit> dabit_buffer;
  map edabit_buffers;

  int my_num,num_players;

  const string prep_data_dir;
  int thread_num;

  Sub_Data_Files* part;

  // Same tag dispatch as Preprocessing::get_edabits: the true_type
  // overload always throws not_implemented.
  void buffer_edabits_with_queues(bool strict, int n_bits)
  { buffer_edabits_with_queues<0>(strict, n_bits, T::clear::characteristic_two); }
  template
  void buffer_edabits_with_queues(bool strict, int n_bits, false_type);
  template
  void buffer_edabits_with_queues(bool, int, true_type)
  { throw not_implemented(); }

public:
  // File-name helpers, defined elsewhere.
  static string get_filename(const Names& N, Dtype type, int thread_num = -1);
  static string get_input_filename(const Names& N, int input_player,
      int thread_num = -1);
  static string get_edabit_filename(const Names& N, int n_bits,
      int thread_num = -1);

  Sub_Data_Files(int my_num, int num_players, const string& prep_data_dir,
      DataPositions& usage, int thread_num = -1);
  Sub_Data_Files(const Names& N, DataPositions& usage, int thread_num = -1);
  // Delegates to the first constructor with N.my_num() and N.num_players().
  Sub_Data_Files(const Names& N, const string& prep_data_dir,
      DataPositions& usage, int thread_num = -1) :
      Sub_Data_Files(N.my_num(), N.num_players(), prep_data_dir, usage, thread_num)
  {
  }
  ~Sub_Data_Files();

  // The protocol is not used by this implementation.
  void set_protocol(typename T::Protocol& protocol) { (void) protocol; }

  void seekg(DataPositions& pos);
  void prune();
  void purge();

  bool eof(Dtype dtype);
  bool input_eof(int player);

  void get_no_count(Dtype dtype, T* a);

  // Reads three consecutive elements from the buffer for dtype.
  void get_three_no_count(Dtype dtype, T& a, T& b, T& c)
  {
    buffers[dtype].input(a);
    buffers[dtype].input(b);
    buffers[dtype].input(c);
  }
  // Reads two consecutive elements from the buffer for dtype.
  void get_two_no_count(Dtype dtype, T& a, T& b)
  {
    buffers[dtype].input(a);
    buffers[dtype].input(b);
  }
  // Reads one element from the buffer for dtype.
  void get_one_no_count(Dtype dtype, T& a)
  {
    buffers[dtype].input(a);
  }

  // For this player's own input (i == my_num) both a and x are read via a
  // RefInputTuple; for any other player only a is read and x is left
  // untouched.
  void get_input_no_count(T& a,typename T::open_type& x,int i)
  {
    RefInputTuple tuple(a, x);
    if (i==my_num)
      my_input_buffers.input(tuple);
    else
      input_buffers[i].input(a);
  }

  void setup_extended(const DataTag& tag, int tuple_size = 0);
  void get_no_count(vector& S, DataTag tag, const vector& regs, int vector_size);
  void get_dabit_no_count(T& a, typename T::bit_type& b);

  Preprocessing& get_part();
};

/* Bundles two Preprocessing instances (DataFp and DataF2) with the usage
 * counters they report to.
 */
template class Data_Files
{
  friend class Processor;

  // get_usage() reports usage minus skipped.
  DataPositions usage, skipped;

public:
  // Held by reference.
  // NOTE(review): ownership is not visible here; presumably handled by the
  // constructor/destructor defined elsewhere.
  Preprocessing& DataFp;
  Preprocessing& DataF2;

  Data_Files(Machine& machine, SubProcessor* procp = 0,
      SubProcessor* proc2 = 0);
  Data_Files(const Names& N);
  ~Data_Files();

  DataPositions tellg();
  void seekg(DataPositions& pos);
  void skip(const DataPositions& pos);
  void prune();
  void purge();

  // Usage net of whatever was recorded in `skipped`.
  DataPositions get_usage()
  {
    return usage - skipped;
  }
  // Resets both counter sets.
  void reset_usage() { usage.reset(); skipped.reset(); }

  // Sum of the statistics of both preprocessing instances.
  NamedCommStats comm_stats() { return DataFp.comm_stats() + DataF2.comm_stats(); }
};

// Returns the eof member of the buffer for dtype.
template inline
bool Sub_Data_Files::eof(Dtype dtype)
  { return buffers[dtype].eof; }

// Returns the eof member of the input buffer belonging to player
// (the dedicated buffer is used for this player's own inputs).
template inline
bool Sub_Data_Files::input_eof(int player)
{
  if (player == my_num)
    return my_input_buffers.eof;
  else
    return input_buffers[player].eof;
}

// Reads DataPositions::tuple_size[dtype] elements into a[0..] without
// counting.
template inline
void Sub_Data_Files::get_no_count(Dtype dtype, T* a)
{
  for (int i = 0; i < DataPositions::tuple_size[dtype]; i++)
    buffers[dtype].input(a[i]);
}

// Counting fetch into an array: three elements for DATA_TRIPLE, two for
// DATA_SQUARE and DATA_INVERSE, one for DATA_BIT. Any other dtype throws
// runtime_error.
template
inline void Preprocessing::get(Dtype dtype, T* a)
{
  switch (dtype)
  {
  case DATA_TRIPLE:
      get_three(dtype, a[0], a[1], a[2]);
      break;
  case DATA_SQUARE:
  case DATA_INVERSE:
      get_two(dtype, a[0], a[1]);
      break;
  case DATA_BIT:
      get_one(dtype, a[0]);
      break;
  default:
      throw runtime_error("unsupported data type: " + to_string(dtype));
  }
}

// Fetches three elements; the request is counted here only when the field
// type is not DATA_GF2.
template
inline void Preprocessing::get_three(Dtype dtype, T& a, T& b, T& c)
{
  // count bit triples in get_triple()
  if (T::clear::field_type() != DATA_GF2)
    count(dtype);
  get_three_no_count(dtype, a, b, c);
}

// Counts one request of dtype, then fetches two elements.
template
inline void Preprocessing::get_two(Dtype dtype, T& a, T& b)
{
  count(dtype);
  get_two_no_count(dtype, a, b);
}

// Counts one request of dtype, then fetches one element.
template
inline void Preprocessing::get_one(Dtype dtype, T& a)
{
  count(dtype);
  get_one_no_count(dtype, a);
}

// Counts one input for player i, then fetches it.
template
inline void Preprocessing::get_input(T& a, typename T::open_type& x, int i)
{
  count_input(i);
  get_input_no_count(a, x, i);
}

// Tag-based fetch: records vector_size uses of tag via usage.count(), then
// forwards to get_no_count().
// NOTE(review): this path calls usage.count() directly, so it does not go
// through the do_count multiplier used by count() above.
template
inline void Preprocessing::get(vector& S, DataTag tag, const vector& regs, int vector_size)
{
  usage.count(T::clear::field_type(), tag, vector_size);
  get_no_count(S, tag, regs, vector_size);
}

// For field type DATA_GF2 the triple is counted here with weight n_bits
// (get_three() skips counting in that case); then the triple is fetched.
template
array Preprocessing::get_triple(int n_bits)
{
  if (T::clear::field_type() == DATA_GF2)
    count(DATA_TRIPLE, n_bits);
  return get_triple_no_count(n_bits);
}

// Default implementation: fetches a DATA_TRIPLE through get().
// The assertion requires that the field type is not DATA_GF2, or that
// T::default_length is 1 or equals n_bits, or that counting is disabled.
template
array Preprocessing::get_triple_no_count(int n_bits)
{
  assert(T::clear::field_type() != DATA_GF2 or T::default_length == 1 or
      T::default_length == n_bits or not do_count);
  array res;
  get(DATA_TRIPLE, res.data());
  return res;
}

// Default implementation: fetches one DATA_BIT element (counted).
template
T Preprocessing::get_bit()
{
  T res;
  get_one(DATA_BIT, res);
  return res;
}

// Default implementation: delegates to get_random_from_inputs() with the
// number of players taken from usage.inputs.
template
T Preprocessing::get_random()
{
  return get_random_from_inputs(usage.inputs.size());
}

// Purges both preprocessing instances.
template
inline void Data_Files::purge()
{
  DataFp.purge();
  DataF2.purge();
}

#endif