Fixed- and floating-point inputs.

This commit is contained in:
Marcel Keller
2019-07-11 14:59:18 +10:00
parent 8f0f25fa5b
commit 5ef70589cb
34 changed files with 304 additions and 83 deletions

View File

@@ -532,14 +532,16 @@ class InputAccess
GC::Secret<EvalRegister>& dest;
GC::Processor<GC::Secret<EvalRegister> >& processor;
ProgramParty& party;
InputArgs args;
public:
InputAccess(int from, int n_bits, GC::Secret<EvalRegister>& dest,
InputAccess(const InputArgs& args,
GC::Processor<GC::Secret<EvalRegister> >& processor) :
from(from), n_bits(n_bits), dest(dest), processor(processor), party(
ProgramParty::s())
from(args.from + 1), n_bits(args.n_bits), dest(
processor.S[args.dest]), processor(processor), party(
ProgramParty::s()), args(args)
{
if (from > party.get_n_parties() or n_bits > 100)
if (from > unsigned(party.get_n_parties()) or n_bits > 100)
throw runtime_error("invalid input parameters");
}
@@ -550,7 +552,7 @@ public:
party.load_wire(reg);
if (from == party.get_id())
{
long long in = processor.get_input(n_bits);
long long in = processor.get_input(args.params);
for (size_t i = 0; i < n_bits; i++)
{
auto& reg = dest.get_reg(i);
@@ -599,10 +601,10 @@ void EvalRegister::inputb(GC::Processor<GC::Secret<EvalRegister> >& processor,
vector<octetStream> oss(party.get_n_parties());
octetStream& my_os = oss[party.get_id() - 1];
vector<InputAccess> accesses;
for (size_t j = 0; j < args.size(); j += 3)
InputArgList a(args);
for (auto x : a)
{
accesses.push_back(
{ args[j] + 1, args[j + 1], processor.S[args[j + 2]], processor });
accesses.push_back({x , processor});
}
for (auto& access : accesses)
access.prepare_masks(my_os);

View File

@@ -17,6 +17,7 @@ using namespace std;
#include "GC/Clear.h"
#include "GC/Memory.h"
#include "GC/Access.h"
#include "GC/ArgTuples.h"
#include "Math/gf2n.h"
#include "Tools/FlexBuffer.h"
@@ -261,8 +262,8 @@ public:
// most BMR phases don't need actual input
template<class T>
static T get_input(int from, GC::Processor<T>& processor, int n_bits)
{ (void)processor; return T::input(from, 0, n_bits); }
static T get_input(GC::Processor<T>& processor, const InputArgs& args)
{ (void)processor; return T::input(args.from + 1, 0, args.n_bits); }
char get_output() { return 0; }
@@ -314,9 +315,9 @@ public:
static void inputb(T& processor, const vector<int>& args);
template <class T>
static T get_input(int from, GC::Processor<T>& processor, int n_bits)
static T get_input(GC::Processor<T>& processor, const InputArgs& args)
{
(void)from, (void)processor, (void)n_bits;
(void)processor, (void)args;
throw runtime_error("use EvalRegister::inputb()");
}

View File

@@ -180,7 +180,7 @@ class reveal(base.Instruction):
class inputb(base.DoNotEliminateInstruction, base.VarArgsInstruction):
__slots__ = []
code = opcodes['INPUTB']
arg_format = tools.cycle(['p','int','sbw'])
arg_format = tools.cycle(['p','int','int','sbw'])
class print_reg(base.IOInstruction):
code = base.opcodes['PRINTREG']

View File

@@ -212,7 +212,7 @@ class sbits(bits):
if n_bits is None:
n_bits = cls.n
res = cls()
inst.inputb(player, n_bits, res)
inst.inputb(player, n_bits, 0, res)
return res
# compatiblity to sint
get_raw_input_from = get_input_from
@@ -648,6 +648,11 @@ class sbitfix(_fix):
return sbitfixvec._new(sbitintvec(v))
else:
return super(sbitfix, cls).load_mem(address)
@classmethod
def get_input_from(cls, player):
v = cls.int_type()
inst.inputb(player, cls.k, cls.f, v)
return cls._new(v)
def __xor__(self, other):
return type(self)(self.v ^ other.v)
def __mul__(self, other):

View File

@@ -386,6 +386,7 @@ class Merger:
last_print_str = None
last = defaultdict(lambda: defaultdict(lambda: None))
last_open = deque()
last_text_input = None
depths = [0] * len(block.instructions)
self.depths = depths
@@ -471,6 +472,13 @@ class Merger:
else:
write(reg, n)
# will be merged
if isinstance(instr, TextInputInstruction):
if last_text_input is not None and \
type(block.instructions[last_text_input]) is not type(instr):
add_edge(last_text_input, n)
last_text_input = n
if isinstance(instr, merge_classes):
open_nodes.add(n)
G.add_node(n, merges=[])

View File

@@ -856,7 +856,7 @@ class prep(base.Instruction):
@base.gf2n
@base.vectorize
class asm_input(base.VarArgsInstruction):
class asm_input(base.TextInputInstruction):
r""" Receive input from player $p$ and put in register $s_i$. """
__slots__ = []
code = base.opcodes['INPUT']
@@ -870,6 +870,28 @@ class asm_input(base.VarArgsInstruction):
def execute(self):
self.args[0].value = _python_input("Enter player %d's input:" % self.args[1]) % program.P
class inputfix(base.TextInputInstruction):
__slots__ = []
code = base.opcodes['INPUTFIX']
arg_format = tools.cycle(['sw', 'int', 'p'])
field_type = 'modp'
def add_usage(self, req_node):
for player in self.args[2::3]:
req_node.increment((self.field_type, 'input', player), \
self.get_size())
class inputfloat(base.TextInputInstruction):
__slots__ = []
code = base.opcodes['INPUTFLOAT']
arg_format = tools.cycle(['sw', 'sw', 'sw', 'sw', 'int', 'p'])
field_type = 'modp'
def add_usage(self, req_node):
for player in self.args[5::6]:
req_node.increment((self.field_type, 'input', player), \
4 * self.get_size())
@base.gf2n
class startinput(base.RawInputInstruction):
r""" Receive inputs from player $p$. """

View File

@@ -100,6 +100,8 @@ opcodes = dict(
PREP = 0x57,
# Input
INPUT = 0x60,
INPUTFIX = 0xF0,
INPUTFLOAT = 0xF1,
STARTINPUT = 0x61,
STOPINPUT = 0x62,
READSOCKETC = 0x63,
@@ -592,6 +594,10 @@ class Instruction(object):
def __repr__(self):
return self.__class__.__name__ + '(' + self.get_pre_arg() + ','.join(str(a) for a in self.args) + ')'
class VarArgsInstruction(Instruction):
def has_var_args(self):
return True
###
### Basic arithmetic
###
@@ -692,6 +698,10 @@ class PublicFileIOInstruction(DoNotEliminateInstruction):
""" Instruction to reads/writes public information from/to files. """
__slots__ = []
class TextInputInstruction(VarArgsInstruction, DoNotEliminateInstruction):
""" Input from text file or stdin """
__slots__ = []
###
### Data access instructions
###
@@ -784,11 +794,6 @@ class JumpInstruction(Instruction):
return self.args[self.jump_arg]
class VarArgsInstruction(Instruction):
def has_var_args(self):
return True
class CISC(Instruction):
"""
Base class for a CISC instruction.

View File

@@ -79,7 +79,9 @@ class Program(object):
Compiler.instructions.dotprods_class, \
Compiler.instructions.gdotprods_class, \
Compiler.instructions.asm_input_class, \
Compiler.instructions.gasm_input_class]
Compiler.instructions.gasm_input_class,
Compiler.instructions.inputfix,
Compiler.instructions.inputfloat]
import Compiler.GC.instructions as gc
self.to_merge += [gc.ldmsdi, gc.stmsdi, gc.ldmsd, gc.stmsd, \
gc.stmsdci, gc.xors, gc.andrs, gc.ands, gc.inputb]

View File

@@ -2408,6 +2408,12 @@ class sfix(_fix):
int_type = sint
clear_type = cfix
@classmethod
def get_input_from(cls, player):
v = cls.int_type()
inputfix(v, cls.f, player)
return cls._new(v)
@classmethod
def coerce(cls, other):
return parse_type(other)
@@ -2728,6 +2734,15 @@ class sfloat(_number, _structure):
'with %d exponent bits' % (vv, plen))
return v, p, z, s
@classmethod
def get_input_from(cls, player):
v = sint()
p = sint()
z = sint()
s = sint()
inputfloat(v, p, z, s, cls.vlen, player)
return cls(v, p, z, s)
@vectorize_init
@read_mem_value
def __init__(self, v, p=None, z=None, s=None, size=None):

View File

@@ -27,7 +27,7 @@ public:
ArgIter<T> operator++()
{
auto res = it;
it += 3;
it += T::n;
return res;
}
@@ -64,16 +64,19 @@ public:
class InputArgs
{
public:
static const int n = 3;
static const int n = 4;
int from;
int n_bits;
int& n_bits;
int& n_shift;
int params[2];
int dest;
InputArgs(vector<int>::const_iterator it)
InputArgs(vector<int>::const_iterator it) : n_bits(params[0]), n_shift(params[1])
{
from = *it++;
n_bits = *it++;
n_shift = *it++;
dest = *it++;
}
};

View File

@@ -75,9 +75,9 @@ void FakeSecret::trans(Processor<FakeSecret>& processor, int n_outputs,
processor.S[args[i]] = square.rows[i];
}
FakeSecret FakeSecret::input(int from, GC::Processor<FakeSecret>& processor, int n_bits)
FakeSecret FakeSecret::input(GC::Processor<FakeSecret>& processor, const InputArgs& args)
{
return input(from, processor.get_input(n_bits), n_bits);
return input(args.from, processor.get_input(args.params), args.n_bits);
}
FakeSecret FakeSecret::input(int from, const int128& input, int n_bits)

View File

@@ -9,6 +9,7 @@
#include "GC/Clear.h"
#include "GC/Memory.h"
#include "GC/Access.h"
#include "GC/ArgTuples.h"
#include "Math/gf2nlong.h"
@@ -62,7 +63,7 @@ public:
static void convcbit(Integer& dest, const Clear& source) { dest = source; }
static FakeSecret input(int from, GC::Processor<FakeSecret>& processor, int n_bits);
static FakeSecret input(GC::Processor<FakeSecret>& processor, const InputArgs& args);
static FakeSecret input(int from, const int128& input, int n_bits);
FakeSecret() : a(0) {}

View File

@@ -37,7 +37,7 @@ class Processor : public ::ProcessorBase
public:
static int check_args(const vector<int>& args, int n);
static void check_input(long long in, int n_bits);
static void check_input(bigint in, int n_bits);
Machine<T>& machine;
@@ -61,7 +61,7 @@ public:
template<class U>
void reset(const U& program);
long long get_input(int n_bits, bool interactive = false);
long long get_input(const int* params, bool interactive = false);
void bitcoms(T& x, const vector<int>& regs) { x.bitcom(S, regs); }
void bitdecs(const vector<int>& regs, const T& x) { x.bitdec(S, regs); }

View File

@@ -11,6 +11,7 @@ using namespace std;
#include "GC/Program.h"
#include "Access.h"
#include "Processor/FixInput.h"
namespace GC
{
@@ -50,27 +51,29 @@ void Processor<T>::reset(const U& program)
}
template<class T>
inline long long GC::Processor<T>::get_input(int n_bits, bool interactive)
inline long long GC::Processor<T>::get_input(const int* params, bool interactive)
{
long long res = ProcessorBase::get_input(interactive);
bigint res = ProcessorBase::get_input<FixInput>(interactive, &params[1]).items[0];
int n_bits = *params;
check_input(res, n_bits);
return res;
assert(n_bits <= 64);
return res.get_si();
}
template<class T>
void GC::Processor<T>::check_input(long long in, int n_bits)
void GC::Processor<T>::check_input(bigint in, int n_bits)
{
auto test = in >> (n_bits - 1);
if (n_bits == 1)
{
if (not (in == 0 or in == 1))
throw runtime_error("input not a bit: " + to_string(in));
throw runtime_error("input not a bit: " + in.get_str());
}
else if (not (test == 0 or test == -1))
{
throw runtime_error(
"input too large for a " + std::to_string(n_bits)
+ "-bit signed integer: " + to_string(in));
+ "-bit signed integer: " + in.get_str());
}
}
@@ -182,11 +185,10 @@ void Processor<T>::and_(const vector<int>& args, bool repeat)
template <class T>
void Processor<T>::input(const vector<int>& args)
{
check_args(args, 3);
for (size_t i = 0; i < args.size(); i += 3)
InputArgList a(args);
for (auto x : a)
{
int n_bits = args[i + 1];
S[args[i+2]] = T::input(args[i] + 1, *this, n_bits);
S[x.dest] = T::input(*this, x);
#ifdef DEBUG_INPUT
cout << "input to " << args[i+2] << "/" << &S[args[i+2]] << endl;
#endif

View File

@@ -79,19 +79,16 @@ void ReplicatedSecret<U>::inputb(Processor<U>& processor,
party.os.resize(2);
for (auto& o : party.os)
o.reset_write_head();
processor.check_args(args, 3);
InputArgList a(args);
bool interactive = party.n_interactive_inputs_from_me(a) > 0;
for (size_t i = 0; i < args.size(); i += 3)
for (auto x : a)
{
int from = args[i];
int n_bits = args[i + 1];
if (from == party.P->my_num())
if (x.from == party.P->my_num())
{
auto& res = processor.S[args[i + 2]];
res.prepare_input(party.os, processor.get_input(n_bits, interactive), n_bits, party.secure_prng);
auto& res = processor.S[x.dest];
res.prepare_input(party.os, processor.get_input(x.params, interactive), x.n_bits, party.secure_prng);
}
}
@@ -101,23 +98,23 @@ void ReplicatedSecret<U>::inputb(Processor<U>& processor,
for (int i = 0; i < 2; i++)
party.P->pass_around(party.os[i], i + 1);
for (size_t i = 0; i < args.size(); i += 3)
for (auto x : a)
{
int from = args[i];
int n_bits = args[i + 1];
int from = x.from;
int n_bits = x.n_bits;
if (from != party.P->my_num())
{
auto& res = processor.S[args[i + 2]];
auto& res = processor.S[x.dest];
res.finalize_input(party, party.os[party.P->get_offset(from) == 1], from, n_bits);
}
}
}
template<class U>
U ReplicatedSecret<U>::input(int from, Processor<U>& processor, int n_bits)
U ReplicatedSecret<U>::input(Processor<U>& processor, const InputArgs& args)
{
// BMR stuff counts from 1
from--;
int from = args.from;
int n_bits = args.n_bits;
auto& party = ReplicatedParty<U>::s();
U res;
party.os.resize(2);
@@ -125,7 +122,7 @@ U ReplicatedSecret<U>::input(int from, Processor<U>& processor, int n_bits)
o.reset_write_head();
if (from == party.P->my_num())
{
res.prepare_input(party.os, processor.get_input(n_bits), n_bits, party.secure_prng);
res.prepare_input(party.os, processor.get_input(args.params), n_bits, party.secure_prng);
party.P->send_relative(party.os);
}
else

View File

@@ -12,6 +12,7 @@ using namespace std;
#include "GC/Memory.h"
#include "GC/Clear.h"
#include "GC/Access.h"
#include "GC/ArgTuples.h"
#include "Math/FixedVec.h"
#include "Math/BitVec.h"
#include "Tools/SwitchableOutput.h"
@@ -65,7 +66,7 @@ public:
static BitVec get_mask(int n) { return n >= 64 ? -1 : ((1L << n) - 1); }
static U input(int from, Processor<U>& processor, int n_bits);
static U input(Processor<U>& processor, const InputArgs& args);
void prepare_input(vector<octetStream>& os, long input, int n_bits, PRNG& secure_prng);
void finalize_input(Thread<U>& party, octetStream& o, int from, int n_bits);

View File

@@ -13,6 +13,7 @@
#include "GC/Clear.h"
#include "GC/Memory.h"
#include "GC/Access.h"
#include "GC/ArgTuples.h"
#include "Math/gf2nlong.h"
@@ -83,7 +84,7 @@ public:
static const T& cast(const T& reg) { return *reinterpret_cast<const T*>(&reg); }
static Secret<T> input(party_id_t from, const int128& input, int n_bits = -1);
static Secret<T> input(party_id_t from, Processor<Secret<T>>& processor, int n_bits = -1);
static Secret<T> input(Processor<Secret<T>>& processor, const InputArgs& args);
void random(int n_bits, int128 share);
void random_bit();
static Secret<T> reconstruct(const int128& x, int length);

View File

@@ -10,9 +10,9 @@ namespace GC
{
template<class T>
Secret<T> Secret<T>::input(party_id_t from, Processor<Secret<T>>& processor, int n_bits)
Secret<T> Secret<T>::input(Processor<Secret<T>>& processor, const InputArgs& args)
{
return T::get_input(from, processor, n_bits);
return T::get_input(processor, args);
}
template<class T>

View File

@@ -109,7 +109,8 @@ class gfp_
gfp_(const __m128i& x) { *this=x; }
gfp_(const int128& x) { *this=x.a; }
gfp_(const bigint& x) { to_modp(a, x, ZpD); }
gfp_(int x) { assign(x); }
gfp_(int x) { assign(x); }
gfp_(long x) { assign(x); }
gfp_(const void* buffer) { assign((char*)buffer); }
template<int Y>
gfp_(const gfp_<Y, L>& x);

15
Processor/FixInput.cpp Normal file
View File

@@ -0,0 +1,15 @@
/*
* FixInput.cpp
*
*/
#include "FixInput.h"
const char* FixInput::NAME = "real number";
void FixInput::read(std::istream& in, const int* params)
{
mpf_class x;
in >> x;
items[0] = x << *params;
}

25
Processor/FixInput.h Normal file
View File

@@ -0,0 +1,25 @@
/*
* FixInput.h
*
*/
#ifndef PROCESSOR_FIXINPUT_H_
#define PROCESSOR_FIXINPUT_H_
#include <iostream>
#include "Math/bigint.h"
class FixInput
{
public:
const static int N_DEST = 1;
const static int N_PARAM = 1;
const static char* NAME;
bigint items[N_DEST];
void read(std::istream& in, const int* params);
};
#endif /* PROCESSOR_FIXINPUT_H_ */

23
Processor/FloatInput.cpp Normal file
View File

@@ -0,0 +1,23 @@
/*
* FloatInput.cpp
*
*/
#include "FloatInput.h"
#include <math.h>
const char* FloatInput::NAME = "real number";
void FloatInput::read(std::istream& in, const int* params)
{
double x;
in >> x;
int exp;
double mant = fabs(frexp(x, &exp));
items[0] = round(mant * (1LL << *params));
items[1] = exp - *params;
items[2] = (x == 0);
items[3] = (x < 0);
}

25
Processor/FloatInput.h Normal file
View File

@@ -0,0 +1,25 @@
/*
* FloatInput.h
*
*/
#ifndef PROCESSOR_FLOATINPUT_H_
#define PROCESSOR_FLOATINPUT_H_
#include "Math/bigint.h"
#include <iostream>
class FloatInput
{
public:
const static int N_DEST = 4;
const static int N_PARAM = 1;
const static char* NAME;
long items[N_DEST];
void read(std::istream& in, const int* params);
};
#endif /* PROCESSOR_FLOATINPUT_H_ */

View File

@@ -31,6 +31,7 @@ protected:
public:
int values_input;
template<class U>
static void input(SubProcessor<T>& Proc, const vector<int>& args);
InputBase(ArithmeticProcessor* proc);

View File

@@ -184,35 +184,39 @@ T InputBase<T>::finalize(int player)
}
template<class T>
template<class U>
void InputBase<T>::input(SubProcessor<T>& Proc,
const vector<int>& args)
{
auto& input = Proc.input;
for (int i = 0; i < Proc.P.num_players(); i++)
input.reset(i);
assert(args.size() % 2 == 0);
int n_arg_tuple = U::N_DEST + U::N_PARAM + 1;
assert(args.size() % n_arg_tuple == 0);
int n_from_me = 0;
if (Proc.Proc.opts.interactive and Proc.Proc.thread_num == 0)
{
for (size_t i = 1; i < args.size(); i += 2)
for (size_t i = n_arg_tuple - 1; i < args.size(); i += n_arg_tuple)
n_from_me += (args[i] == Proc.P.my_num());
if (n_from_me > 0)
cout << "Please input " << n_from_me << " number(s):" << endl;
cout << "Please input " << n_from_me << " " << U::NAME << "(s):" << endl;
}
for (size_t i = 0; i < args.size(); i += 2)
for (size_t i = U::N_DEST; i < args.size(); i += n_arg_tuple)
{
int n = args[i + 1];
int n = args[i + U::N_PARAM];
if (n == Proc.P.my_num())
{
long x = Proc.Proc.get_input(n_from_me > 0);
input.add_mine(x);
U tuple = Proc.Proc.template get_input<U>(n_from_me > 0, &args[i]);
for (auto x : tuple.items)
input.add_mine(x);
}
else
{
input.add_other(n);
for (int j = 0; j < U::N_DEST; j++)
input.add_other(n);
}
}
@@ -222,9 +226,10 @@ void InputBase<T>::input(SubProcessor<T>& Proc,
input.send_mine();
vector<vector<int>> regs(Proc.P.num_players());
for (size_t i = 0; i < args.size(); i += 2)
for (size_t i = 0; i < args.size(); i += n_arg_tuple)
{
regs[args[i + 1]].push_back(args[i]);
for (int j = 0; j < U::N_DEST; j++)
regs[args[i + n_arg_tuple - 1]].push_back(args[i + j]);
}
for (int i = 0; i < Proc.P.num_players(); i++)
input.stop(i, regs[i]);

View File

@@ -101,6 +101,8 @@ enum
PREP = 0x57,
// Input
INPUT = 0x60,
INPUTFIX = 0xF0,
INPUTFLOAT = 0xF1,
STARTINPUT = 0x61,
STOPINPUT = 0x62,
READSOCKETC = 0x63,

View File

@@ -2,6 +2,9 @@
#include "Processor/Instruction.h"
#include "Processor/Machine.h"
#include "Processor/Processor.h"
#include "Processor/IntInput.h"
#include "Processor/FixInput.h"
#include "Processor/FloatInput.h"
#include "Exceptions/Exceptions.h"
#include "Tools/time-func.h"
#include "Tools/parse.h"
@@ -288,6 +291,8 @@ void BaseInstruction::parse_operands(istream& s, int pos)
case GDOTPRODS:
case INPUT:
case GINPUT:
case INPUTFIX:
case INPUTFLOAT:
num_var_args = get_int(s);
get_vector(num_var_args, start, s);
break;
@@ -987,10 +992,16 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
Proc.temp.ans2.output(Proc.private_output, false);
break;
case INPUT:
sint::Input::input(Proc.Procp, start);
sint::Input::template input<IntInput>(Proc.Procp, start);
break;
case GINPUT:
sgf2n::Input::input(Proc.Proc2, start);
sgf2n::Input::template input<IntInput>(Proc.Proc2, start);
break;
case INPUTFIX:
sint::Input::template input<FixInput>(Proc.Procp, start);
break;
case INPUTFLOAT:
sint::Input::template input<FloatInput>(Proc.Procp, start);
break;
case STARTINPUT:
Proc.Procp.input.start(r[0],n);

14
Processor/IntInput.cpp Normal file
View File

@@ -0,0 +1,14 @@
/*
* IntInput.cpp
*
*/
#include "IntInput.h"
const char* IntInput::NAME = "integer";
void IntInput::read(std::istream& in, const int* params)
{
(void) params;
in >> items[0];
}

23
Processor/IntInput.h Normal file
View File

@@ -0,0 +1,23 @@
/*
* IntInput.h
*
*/
#ifndef PROCESSOR_INTINPUT_H_
#define PROCESSOR_INTINPUT_H_
#include <iostream>
class IntInput
{
public:
const static int N_DEST = 1;
const static int N_PARAM = 0;
const static char* NAME;
long items[N_DEST];
void read(std::istream& in, const int* params);
};
#endif /* PROCESSOR_INTINPUT_H_ */

View File

@@ -4,6 +4,9 @@
*/
#include "ProcessorBase.h"
#include "IntInput.h"
#include "FixInput.h"
#include "FloatInput.h"
#include "Exceptions/Exceptions.h"
#include <iostream>
@@ -23,21 +26,27 @@ void ProcessorBase::open_input_file(int my_num, int thread_num)
open_input_file(input_file);
}
long long ProcessorBase::get_input(bool interactive)
template<class T>
T ProcessorBase::get_input(bool interactive, const int* params)
{
if (interactive)
return get_input(cin, "standard input");
return get_input<T>(cin, "standard input", params);
else
return get_input(input_file, input_filename);
return get_input<T>(input_file, input_filename, params);
}
long long ProcessorBase::get_input(istream& input_file, const string& input_filename)
template<class T>
T ProcessorBase::get_input(istream& input_file, const string& input_filename, const int* params)
{
long long res;
input_file >> res;
T res;
res.read(input_file, params);
if (input_file.eof())
throw IO_Error("not enough inputs in " + input_filename);
if (input_file.fail())
throw IO_Error("cannot read from " + input_filename);
return res;
}
template IntInput ProcessorBase::get_input(bool, const int*);
template FixInput ProcessorBase::get_input(bool, const int*);
template FloatInput ProcessorBase::get_input(bool, const int*);

View File

@@ -40,8 +40,10 @@ public:
void open_input_file(const string& name);
void open_input_file(int my_num, int thread_num);
long long get_input(bool interactive);
long long get_input(istream& is, const string& input_filename);
template<class T>
T get_input(bool interactive, const int* params);
template<class T>
T get_input(istream& is, const string& input_filename, const int* params);
};
#endif /* PROCESSOR_PROCESSORBASE_H_ */

View File

@@ -102,7 +102,7 @@ data = Matrix(3, 2, sfix)
for i in range(3):
for j in range(2):
data[i][j] = sfix.from_sint(sint.get_input_from(j))
data[i][j] = sfix.get_input_from(j)
# compute weighted average

View File

@@ -119,7 +119,7 @@ void YaoEvalWire::inputb(GC::Processor<GC::Secret<YaoEvalWire> >& processor,
}
else
{
long long input = processor.get_input(x.n_bits, interactive);
long long input = processor.get_input(x.params, interactive);
size_t start = inputs.size();
inputs.resize(start + x.n_bits);
for (int i = 0; i < x.n_bits; i++)

View File

@@ -204,7 +204,7 @@ void YaoGarbleWire::inputb(GC::Processor<GC::Secret<YaoGarbleWire>>& processor,
dest.resize_regs(x.n_bits);
if (x.from == 0)
{
long long input = processor.get_input(x.n_bits, interactive);
long long input = processor.get_input(x.params, interactive);
for (auto& reg : dest.get_regs())
{
reg.public_input(input & 1);