Allow overwriting of persistence files.

This commit is contained in:
Marcel Keller
2022-01-12 20:11:28 +11:00
parent f343d73b25
commit 0f9d5de697
8 changed files with 70 additions and 23 deletions

View File

@@ -1727,14 +1727,15 @@ class writesharestofile(base.IOInstruction):
""" Write shares to ``Persistence/Transactions-P<playerno>.data``
(appending at the end).
:param: number of shares (int)
:param: number of arguments to follow / number of shares plus one (int)
:param: position (regint, -1 for appending)
:param: source (sint)
:param: (repeat from source)...
"""
__slots__ = []
code = base.opcodes['WRITEFILESHARE']
arg_format = itertools.repeat('s')
arg_format = tools.chain(['ci'], itertools.repeat('s'))
def has_var_args(self):
return True

View File

@@ -2329,16 +2329,20 @@ class sint(_secret, _int):
return stop, shares
@staticmethod
def write_to_file(shares):
def write_to_file(shares, position=None):
""" Write shares to ``Persistence/Transactions-P<playerno>.data``
(appending at the end).
:param: shares (list or iterable of sint)
:param shares: (list or iterable of sint)
:param position: start position (int/regint/cint),
defaults to end of file
"""
for share in shares:
assert isinstance(share, sint)
assert share.size == 1
writesharestofile(*shares)
if position is None:
position = -1
writesharestofile(regint.conv(position), *shares)
@vectorized_classmethod
def load_mem(cls, address, mem_type=None):
@@ -3922,13 +3926,15 @@ class _single(_number, _secret_structure):
return stop, [cls._new(x) for x in shares]
@classmethod
def write_to_file(cls, shares):
def write_to_file(cls, shares, position=None):
""" Write shares of integer representation to
``Persistence/Transactions-P<playerno>.data`` (appending at the end).
``Persistence/Transactions-P<playerno>.data``.
:param: shares (list or iterable of sfix)
:param shares: (list or iterable of sfix)
:param position: start position (int/regint/cint),
defaults to end of file
"""
cls.int_type.write_to_file([x.v for x in shares])
cls.int_type.write_to_file([x.v for x in shares], position)
def store_in_mem(self, address):
""" Store in memory by public address. """
@@ -5389,11 +5395,14 @@ class Array(_vectorizable):
self.assign(shares)
return stop
def write_to_file(self):
def write_to_file(self, position=None):
""" Write shares of integer representation to
``Persistence/Transactions-P<playerno>.data`` (appending at the end).
``Persistence/Transactions-P<playerno>.data``.
:param position: start position (int/regint/cint),
defaults to end of file
"""
self.value_type.write_to_file(list(self))
self.value_type.write_to_file(list(self), position)
def __add__(self, other):
""" Vector addition.
@@ -5723,13 +5732,20 @@ class SubMultiArray(_vectorizable):
def _(i):
self[i].input_from(player, budget=budget, raw=raw)
def write_to_file(self):
def write_to_file(self, position=None):
""" Write shares of integer representation to
``Persistence/Transactions-P<playerno>.data`` (appending at the end).
``Persistence/Transactions-P<playerno>.data``.
:param position: start position (int/regint/cint),
defaults to end of file
"""
@library.for_range(len(self))
def _(i):
self[i].write_to_file()
if position is None:
my_pos = None
else:
my_pos = position + i * self[i].total_size()
self[i].write_to_file(my_pos)
def read_from_file(self, start):
""" Read content from ``Persistence/Transactions-P<playerno>.data``.

View File

@@ -27,7 +27,8 @@ class Binary_File_IO
* Throws file_error.
*/
template <class T>
void write_to_file(const string filename, const vector< T >& buffer);
void write_to_file(const string filename, const vector<T>& buffer,
long start_pos);
/*
* Read from posn in the filename the binary values until the buffer is full.

View File

@@ -14,18 +14,32 @@ inline string Binary_File_IO::filename(int my_number)
}
template<class T>
void Binary_File_IO::write_to_file(const string filename, const vector< T >& buffer)
void Binary_File_IO::write_to_file(const string filename,
const vector<T>& buffer, long start_pos)
{
ofstream outf;
outf.open(filename, ios::out | ios::binary | ios::app);
outf.open(filename, ios::out | ios::binary | ios::ate | ios::in);
if (outf.fail()) { throw file_error(filename); }
if (start_pos != -1)
{
long write_pos = start_pos * T::size();
// fill with zeros if needed
for (long i = outf.tellp(); i < write_pos; i++)
outf.put(0);
outf.seekp(write_pos);
}
for (unsigned int i = 0; i < buffer.size(); i++)
{
buffer[i].output(outf, false);
}
if (outf.fail())
throw runtime_error("failed writing to " + filename);
outf.close();
}

View File

@@ -273,7 +273,6 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
get_vector(2, start, s);
break;
// open instructions + read/write instructions with variable length args
case WRITEFILESHARE:
case OPEN:
case GOPEN:
case MULS:
@@ -376,6 +375,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos)
case BITDECINT:
case EDABIT:
case SEDABIT:
case WRITEFILESHARE:
num_var_args = get_int(s) - 1;
r[0] = get_int(s);
get_vector(num_var_args, start, s);
@@ -1175,7 +1175,7 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
break;
case WRITEFILESHARE:
// Write shares to file system
Proc.write_shares_to_file(start);
Proc.write_shares_to_file(Proc.read_Ci(r[0]), start);
break;
case READFILESHARE:
// Read shares from file system

View File

@@ -94,6 +94,19 @@ Machine<sint, sgf2n>::Machine(int my_number, Names& playerNames,
load_schedule(progname_str);
// initialize persistence if necessary
for (auto& prog : progs)
{
if (prog.writes_persistance)
{
string filename = Binary_File_IO::filename(my_number);
ifstream pers(filename);
if (pers.fail())
ofstream pers(filename, ios::binary);
break;
}
}
#ifdef VERBOSE
progs[0].print_offline_cost();
#endif

View File

@@ -239,7 +239,7 @@ class Processor : public ArithmeticProcessor
// Read and write secret numeric data to file (name hardcoded at present)
void read_shares_from_file(int start_file_pos, int end_file_pos_register, const vector<int>& data_registers);
void write_shares_to_file(const vector<int>& data_registers);
void write_shares_to_file(long start_pos, const vector<int>& data_registers);
cint get_inverse2(unsigned m);

View File

@@ -370,7 +370,9 @@ void Processor<sint, sgf2n>::read_shares_from_file(int start_file_posn, int end_
// Append share data in data_registers to end of file. Expects Persistence directory to exist.
template<class sint, class sgf2n>
void Processor<sint, sgf2n>::write_shares_to_file(const vector<int>& data_registers) {
void Processor<sint, sgf2n>::write_shares_to_file(long start_pos,
const vector<int>& data_registers)
{
string filename = binary_file_io.filename(P.my_num());
unsigned int size = data_registers.size();
@@ -382,7 +384,7 @@ void Processor<sint, sgf2n>::write_shares_to_file(const vector<int>& data_regist
inpbuf[i] = get_Sp_ref(data_registers[i]);
}
binary_file_io.write_to_file(filename, inpbuf);
binary_file_io.write_to_file(filename, inpbuf, start_pos);
}
template <class T>