Fixed- and floating-point inputs.

This commit is contained in:
Marcel Keller
2019-07-11 14:59:18 +10:00
parent 8f0f25fa5b
commit 5ef70589cb
34 changed files with 304 additions and 83 deletions

View File

@@ -532,14 +532,16 @@ class InputAccess
GC::Secret<EvalRegister>& dest; GC::Secret<EvalRegister>& dest;
GC::Processor<GC::Secret<EvalRegister> >& processor; GC::Processor<GC::Secret<EvalRegister> >& processor;
ProgramParty& party; ProgramParty& party;
InputArgs args;
public: public:
InputAccess(int from, int n_bits, GC::Secret<EvalRegister>& dest, InputAccess(const InputArgs& args,
GC::Processor<GC::Secret<EvalRegister> >& processor) : GC::Processor<GC::Secret<EvalRegister> >& processor) :
from(from), n_bits(n_bits), dest(dest), processor(processor), party( from(args.from + 1), n_bits(args.n_bits), dest(
ProgramParty::s()) processor.S[args.dest]), processor(processor), party(
ProgramParty::s()), args(args)
{ {
if (from > party.get_n_parties() or n_bits > 100) if (from > unsigned(party.get_n_parties()) or n_bits > 100)
throw runtime_error("invalid input parameters"); throw runtime_error("invalid input parameters");
} }
@@ -550,7 +552,7 @@ public:
party.load_wire(reg); party.load_wire(reg);
if (from == party.get_id()) if (from == party.get_id())
{ {
long long in = processor.get_input(n_bits); long long in = processor.get_input(args.params);
for (size_t i = 0; i < n_bits; i++) for (size_t i = 0; i < n_bits; i++)
{ {
auto& reg = dest.get_reg(i); auto& reg = dest.get_reg(i);
@@ -599,10 +601,10 @@ void EvalRegister::inputb(GC::Processor<GC::Secret<EvalRegister> >& processor,
vector<octetStream> oss(party.get_n_parties()); vector<octetStream> oss(party.get_n_parties());
octetStream& my_os = oss[party.get_id() - 1]; octetStream& my_os = oss[party.get_id() - 1];
vector<InputAccess> accesses; vector<InputAccess> accesses;
for (size_t j = 0; j < args.size(); j += 3) InputArgList a(args);
for (auto x : a)
{ {
accesses.push_back( accesses.push_back({x , processor});
{ args[j] + 1, args[j + 1], processor.S[args[j + 2]], processor });
} }
for (auto& access : accesses) for (auto& access : accesses)
access.prepare_masks(my_os); access.prepare_masks(my_os);

View File

@@ -17,6 +17,7 @@ using namespace std;
#include "GC/Clear.h" #include "GC/Clear.h"
#include "GC/Memory.h" #include "GC/Memory.h"
#include "GC/Access.h" #include "GC/Access.h"
#include "GC/ArgTuples.h"
#include "Math/gf2n.h" #include "Math/gf2n.h"
#include "Tools/FlexBuffer.h" #include "Tools/FlexBuffer.h"
@@ -261,8 +262,8 @@ public:
// most BMR phases don't need actual input // most BMR phases don't need actual input
template<class T> template<class T>
static T get_input(int from, GC::Processor<T>& processor, int n_bits) static T get_input(GC::Processor<T>& processor, const InputArgs& args)
{ (void)processor; return T::input(from, 0, n_bits); } { (void)processor; return T::input(args.from + 1, 0, args.n_bits); }
char get_output() { return 0; } char get_output() { return 0; }
@@ -314,9 +315,9 @@ public:
static void inputb(T& processor, const vector<int>& args); static void inputb(T& processor, const vector<int>& args);
template <class T> template <class T>
static T get_input(int from, GC::Processor<T>& processor, int n_bits) static T get_input(GC::Processor<T>& processor, const InputArgs& args)
{ {
(void)from, (void)processor, (void)n_bits; (void)processor, (void)args;
throw runtime_error("use EvalRegister::inputb()"); throw runtime_error("use EvalRegister::inputb()");
} }

View File

@@ -180,7 +180,7 @@ class reveal(base.Instruction):
class inputb(base.DoNotEliminateInstruction, base.VarArgsInstruction): class inputb(base.DoNotEliminateInstruction, base.VarArgsInstruction):
__slots__ = [] __slots__ = []
code = opcodes['INPUTB'] code = opcodes['INPUTB']
arg_format = tools.cycle(['p','int','sbw']) arg_format = tools.cycle(['p','int','int','sbw'])
class print_reg(base.IOInstruction): class print_reg(base.IOInstruction):
code = base.opcodes['PRINTREG'] code = base.opcodes['PRINTREG']

View File

@@ -212,7 +212,7 @@ class sbits(bits):
if n_bits is None: if n_bits is None:
n_bits = cls.n n_bits = cls.n
res = cls() res = cls()
inst.inputb(player, n_bits, res) inst.inputb(player, n_bits, 0, res)
return res return res
# compatiblity to sint # compatiblity to sint
get_raw_input_from = get_input_from get_raw_input_from = get_input_from
@@ -648,6 +648,11 @@ class sbitfix(_fix):
return sbitfixvec._new(sbitintvec(v)) return sbitfixvec._new(sbitintvec(v))
else: else:
return super(sbitfix, cls).load_mem(address) return super(sbitfix, cls).load_mem(address)
@classmethod
def get_input_from(cls, player):
v = cls.int_type()
inst.inputb(player, cls.k, cls.f, v)
return cls._new(v)
def __xor__(self, other): def __xor__(self, other):
return type(self)(self.v ^ other.v) return type(self)(self.v ^ other.v)
def __mul__(self, other): def __mul__(self, other):

View File

@@ -386,6 +386,7 @@ class Merger:
last_print_str = None last_print_str = None
last = defaultdict(lambda: defaultdict(lambda: None)) last = defaultdict(lambda: defaultdict(lambda: None))
last_open = deque() last_open = deque()
last_text_input = None
depths = [0] * len(block.instructions) depths = [0] * len(block.instructions)
self.depths = depths self.depths = depths
@@ -471,6 +472,13 @@ class Merger:
else: else:
write(reg, n) write(reg, n)
# will be merged
if isinstance(instr, TextInputInstruction):
if last_text_input is not None and \
type(block.instructions[last_text_input]) is not type(instr):
add_edge(last_text_input, n)
last_text_input = n
if isinstance(instr, merge_classes): if isinstance(instr, merge_classes):
open_nodes.add(n) open_nodes.add(n)
G.add_node(n, merges=[]) G.add_node(n, merges=[])

View File

@@ -856,7 +856,7 @@ class prep(base.Instruction):
@base.gf2n @base.gf2n
@base.vectorize @base.vectorize
class asm_input(base.VarArgsInstruction): class asm_input(base.TextInputInstruction):
r""" Receive input from player $p$ and put in register $s_i$. """ r""" Receive input from player $p$ and put in register $s_i$. """
__slots__ = [] __slots__ = []
code = base.opcodes['INPUT'] code = base.opcodes['INPUT']
@@ -870,6 +870,28 @@ class asm_input(base.VarArgsInstruction):
def execute(self): def execute(self):
self.args[0].value = _python_input("Enter player %d's input:" % self.args[1]) % program.P self.args[0].value = _python_input("Enter player %d's input:" % self.args[1]) % program.P
class inputfix(base.TextInputInstruction):
__slots__ = []
code = base.opcodes['INPUTFIX']
arg_format = tools.cycle(['sw', 'int', 'p'])
field_type = 'modp'
def add_usage(self, req_node):
for player in self.args[2::3]:
req_node.increment((self.field_type, 'input', player), \
self.get_size())
class inputfloat(base.TextInputInstruction):
__slots__ = []
code = base.opcodes['INPUTFLOAT']
arg_format = tools.cycle(['sw', 'sw', 'sw', 'sw', 'int', 'p'])
field_type = 'modp'
def add_usage(self, req_node):
for player in self.args[5::6]:
req_node.increment((self.field_type, 'input', player), \
4 * self.get_size())
@base.gf2n @base.gf2n
class startinput(base.RawInputInstruction): class startinput(base.RawInputInstruction):
r""" Receive inputs from player $p$. """ r""" Receive inputs from player $p$. """

View File

@@ -100,6 +100,8 @@ opcodes = dict(
PREP = 0x57, PREP = 0x57,
# Input # Input
INPUT = 0x60, INPUT = 0x60,
INPUTFIX = 0xF0,
INPUTFLOAT = 0xF1,
STARTINPUT = 0x61, STARTINPUT = 0x61,
STOPINPUT = 0x62, STOPINPUT = 0x62,
READSOCKETC = 0x63, READSOCKETC = 0x63,
@@ -592,6 +594,10 @@ class Instruction(object):
def __repr__(self): def __repr__(self):
return self.__class__.__name__ + '(' + self.get_pre_arg() + ','.join(str(a) for a in self.args) + ')' return self.__class__.__name__ + '(' + self.get_pre_arg() + ','.join(str(a) for a in self.args) + ')'
class VarArgsInstruction(Instruction):
def has_var_args(self):
return True
### ###
### Basic arithmetic ### Basic arithmetic
### ###
@@ -692,6 +698,10 @@ class PublicFileIOInstruction(DoNotEliminateInstruction):
""" Instruction to reads/writes public information from/to files. """ """ Instruction to reads/writes public information from/to files. """
__slots__ = [] __slots__ = []
class TextInputInstruction(VarArgsInstruction, DoNotEliminateInstruction):
""" Input from text file or stdin """
__slots__ = []
### ###
### Data access instructions ### Data access instructions
### ###
@@ -784,11 +794,6 @@ class JumpInstruction(Instruction):
return self.args[self.jump_arg] return self.args[self.jump_arg]
class VarArgsInstruction(Instruction):
def has_var_args(self):
return True
class CISC(Instruction): class CISC(Instruction):
""" """
Base class for a CISC instruction. Base class for a CISC instruction.

View File

@@ -79,7 +79,9 @@ class Program(object):
Compiler.instructions.dotprods_class, \ Compiler.instructions.dotprods_class, \
Compiler.instructions.gdotprods_class, \ Compiler.instructions.gdotprods_class, \
Compiler.instructions.asm_input_class, \ Compiler.instructions.asm_input_class, \
Compiler.instructions.gasm_input_class] Compiler.instructions.gasm_input_class,
Compiler.instructions.inputfix,
Compiler.instructions.inputfloat]
import Compiler.GC.instructions as gc import Compiler.GC.instructions as gc
self.to_merge += [gc.ldmsdi, gc.stmsdi, gc.ldmsd, gc.stmsd, \ self.to_merge += [gc.ldmsdi, gc.stmsdi, gc.ldmsd, gc.stmsd, \
gc.stmsdci, gc.xors, gc.andrs, gc.ands, gc.inputb] gc.stmsdci, gc.xors, gc.andrs, gc.ands, gc.inputb]

View File

@@ -2408,6 +2408,12 @@ class sfix(_fix):
int_type = sint int_type = sint
clear_type = cfix clear_type = cfix
@classmethod
def get_input_from(cls, player):
v = cls.int_type()
inputfix(v, cls.f, player)
return cls._new(v)
@classmethod @classmethod
def coerce(cls, other): def coerce(cls, other):
return parse_type(other) return parse_type(other)
@@ -2728,6 +2734,15 @@ class sfloat(_number, _structure):
'with %d exponent bits' % (vv, plen)) 'with %d exponent bits' % (vv, plen))
return v, p, z, s return v, p, z, s
@classmethod
def get_input_from(cls, player):
v = sint()
p = sint()
z = sint()
s = sint()
inputfloat(v, p, z, s, cls.vlen, player)
return cls(v, p, z, s)
@vectorize_init @vectorize_init
@read_mem_value @read_mem_value
def __init__(self, v, p=None, z=None, s=None, size=None): def __init__(self, v, p=None, z=None, s=None, size=None):

View File

@@ -27,7 +27,7 @@ public:
ArgIter<T> operator++() ArgIter<T> operator++()
{ {
auto res = it; auto res = it;
it += 3; it += T::n;
return res; return res;
} }
@@ -64,16 +64,19 @@ public:
class InputArgs class InputArgs
{ {
public: public:
static const int n = 3; static const int n = 4;
int from; int from;
int n_bits; int& n_bits;
int& n_shift;
int params[2];
int dest; int dest;
InputArgs(vector<int>::const_iterator it) InputArgs(vector<int>::const_iterator it) : n_bits(params[0]), n_shift(params[1])
{ {
from = *it++; from = *it++;
n_bits = *it++; n_bits = *it++;
n_shift = *it++;
dest = *it++; dest = *it++;
} }
}; };

View File

@@ -75,9 +75,9 @@ void FakeSecret::trans(Processor<FakeSecret>& processor, int n_outputs,
processor.S[args[i]] = square.rows[i]; processor.S[args[i]] = square.rows[i];
} }
FakeSecret FakeSecret::input(int from, GC::Processor<FakeSecret>& processor, int n_bits) FakeSecret FakeSecret::input(GC::Processor<FakeSecret>& processor, const InputArgs& args)
{ {
return input(from, processor.get_input(n_bits), n_bits); return input(args.from, processor.get_input(args.params), args.n_bits);
} }
FakeSecret FakeSecret::input(int from, const int128& input, int n_bits) FakeSecret FakeSecret::input(int from, const int128& input, int n_bits)

View File

@@ -9,6 +9,7 @@
#include "GC/Clear.h" #include "GC/Clear.h"
#include "GC/Memory.h" #include "GC/Memory.h"
#include "GC/Access.h" #include "GC/Access.h"
#include "GC/ArgTuples.h"
#include "Math/gf2nlong.h" #include "Math/gf2nlong.h"
@@ -62,7 +63,7 @@ public:
static void convcbit(Integer& dest, const Clear& source) { dest = source; } static void convcbit(Integer& dest, const Clear& source) { dest = source; }
static FakeSecret input(int from, GC::Processor<FakeSecret>& processor, int n_bits); static FakeSecret input(GC::Processor<FakeSecret>& processor, const InputArgs& args);
static FakeSecret input(int from, const int128& input, int n_bits); static FakeSecret input(int from, const int128& input, int n_bits);
FakeSecret() : a(0) {} FakeSecret() : a(0) {}

View File

@@ -37,7 +37,7 @@ class Processor : public ::ProcessorBase
public: public:
static int check_args(const vector<int>& args, int n); static int check_args(const vector<int>& args, int n);
static void check_input(long long in, int n_bits); static void check_input(bigint in, int n_bits);
Machine<T>& machine; Machine<T>& machine;
@@ -61,7 +61,7 @@ public:
template<class U> template<class U>
void reset(const U& program); void reset(const U& program);
long long get_input(int n_bits, bool interactive = false); long long get_input(const int* params, bool interactive = false);
void bitcoms(T& x, const vector<int>& regs) { x.bitcom(S, regs); } void bitcoms(T& x, const vector<int>& regs) { x.bitcom(S, regs); }
void bitdecs(const vector<int>& regs, const T& x) { x.bitdec(S, regs); } void bitdecs(const vector<int>& regs, const T& x) { x.bitdec(S, regs); }

View File

@@ -11,6 +11,7 @@ using namespace std;
#include "GC/Program.h" #include "GC/Program.h"
#include "Access.h" #include "Access.h"
#include "Processor/FixInput.h"
namespace GC namespace GC
{ {
@@ -50,27 +51,29 @@ void Processor<T>::reset(const U& program)
} }
template<class T> template<class T>
inline long long GC::Processor<T>::get_input(int n_bits, bool interactive) inline long long GC::Processor<T>::get_input(const int* params, bool interactive)
{ {
long long res = ProcessorBase::get_input(interactive); bigint res = ProcessorBase::get_input<FixInput>(interactive, &params[1]).items[0];
int n_bits = *params;
check_input(res, n_bits); check_input(res, n_bits);
return res; assert(n_bits <= 64);
return res.get_si();
} }
template<class T> template<class T>
void GC::Processor<T>::check_input(long long in, int n_bits) void GC::Processor<T>::check_input(bigint in, int n_bits)
{ {
auto test = in >> (n_bits - 1); auto test = in >> (n_bits - 1);
if (n_bits == 1) if (n_bits == 1)
{ {
if (not (in == 0 or in == 1)) if (not (in == 0 or in == 1))
throw runtime_error("input not a bit: " + to_string(in)); throw runtime_error("input not a bit: " + in.get_str());
} }
else if (not (test == 0 or test == -1)) else if (not (test == 0 or test == -1))
{ {
throw runtime_error( throw runtime_error(
"input too large for a " + std::to_string(n_bits) "input too large for a " + std::to_string(n_bits)
+ "-bit signed integer: " + to_string(in)); + "-bit signed integer: " + in.get_str());
} }
} }
@@ -182,11 +185,10 @@ void Processor<T>::and_(const vector<int>& args, bool repeat)
template <class T> template <class T>
void Processor<T>::input(const vector<int>& args) void Processor<T>::input(const vector<int>& args)
{ {
check_args(args, 3); InputArgList a(args);
for (size_t i = 0; i < args.size(); i += 3) for (auto x : a)
{ {
int n_bits = args[i + 1]; S[x.dest] = T::input(*this, x);
S[args[i+2]] = T::input(args[i] + 1, *this, n_bits);
#ifdef DEBUG_INPUT #ifdef DEBUG_INPUT
cout << "input to " << args[i+2] << "/" << &S[args[i+2]] << endl; cout << "input to " << args[i+2] << "/" << &S[args[i+2]] << endl;
#endif #endif

View File

@@ -79,19 +79,16 @@ void ReplicatedSecret<U>::inputb(Processor<U>& processor,
party.os.resize(2); party.os.resize(2);
for (auto& o : party.os) for (auto& o : party.os)
o.reset_write_head(); o.reset_write_head();
processor.check_args(args, 3);
InputArgList a(args); InputArgList a(args);
bool interactive = party.n_interactive_inputs_from_me(a) > 0; bool interactive = party.n_interactive_inputs_from_me(a) > 0;
for (size_t i = 0; i < args.size(); i += 3) for (auto x : a)
{ {
int from = args[i]; if (x.from == party.P->my_num())
int n_bits = args[i + 1];
if (from == party.P->my_num())
{ {
auto& res = processor.S[args[i + 2]]; auto& res = processor.S[x.dest];
res.prepare_input(party.os, processor.get_input(n_bits, interactive), n_bits, party.secure_prng); res.prepare_input(party.os, processor.get_input(x.params, interactive), x.n_bits, party.secure_prng);
} }
} }
@@ -101,23 +98,23 @@ void ReplicatedSecret<U>::inputb(Processor<U>& processor,
for (int i = 0; i < 2; i++) for (int i = 0; i < 2; i++)
party.P->pass_around(party.os[i], i + 1); party.P->pass_around(party.os[i], i + 1);
for (size_t i = 0; i < args.size(); i += 3) for (auto x : a)
{ {
int from = args[i]; int from = x.from;
int n_bits = args[i + 1]; int n_bits = x.n_bits;
if (from != party.P->my_num()) if (from != party.P->my_num())
{ {
auto& res = processor.S[args[i + 2]]; auto& res = processor.S[x.dest];
res.finalize_input(party, party.os[party.P->get_offset(from) == 1], from, n_bits); res.finalize_input(party, party.os[party.P->get_offset(from) == 1], from, n_bits);
} }
} }
} }
template<class U> template<class U>
U ReplicatedSecret<U>::input(int from, Processor<U>& processor, int n_bits) U ReplicatedSecret<U>::input(Processor<U>& processor, const InputArgs& args)
{ {
// BMR stuff counts from 1 int from = args.from;
from--; int n_bits = args.n_bits;
auto& party = ReplicatedParty<U>::s(); auto& party = ReplicatedParty<U>::s();
U res; U res;
party.os.resize(2); party.os.resize(2);
@@ -125,7 +122,7 @@ U ReplicatedSecret<U>::input(int from, Processor<U>& processor, int n_bits)
o.reset_write_head(); o.reset_write_head();
if (from == party.P->my_num()) if (from == party.P->my_num())
{ {
res.prepare_input(party.os, processor.get_input(n_bits), n_bits, party.secure_prng); res.prepare_input(party.os, processor.get_input(args.params), n_bits, party.secure_prng);
party.P->send_relative(party.os); party.P->send_relative(party.os);
} }
else else

View File

@@ -12,6 +12,7 @@ using namespace std;
#include "GC/Memory.h" #include "GC/Memory.h"
#include "GC/Clear.h" #include "GC/Clear.h"
#include "GC/Access.h" #include "GC/Access.h"
#include "GC/ArgTuples.h"
#include "Math/FixedVec.h" #include "Math/FixedVec.h"
#include "Math/BitVec.h" #include "Math/BitVec.h"
#include "Tools/SwitchableOutput.h" #include "Tools/SwitchableOutput.h"
@@ -65,7 +66,7 @@ public:
static BitVec get_mask(int n) { return n >= 64 ? -1 : ((1L << n) - 1); } static BitVec get_mask(int n) { return n >= 64 ? -1 : ((1L << n) - 1); }
static U input(int from, Processor<U>& processor, int n_bits); static U input(Processor<U>& processor, const InputArgs& args);
void prepare_input(vector<octetStream>& os, long input, int n_bits, PRNG& secure_prng); void prepare_input(vector<octetStream>& os, long input, int n_bits, PRNG& secure_prng);
void finalize_input(Thread<U>& party, octetStream& o, int from, int n_bits); void finalize_input(Thread<U>& party, octetStream& o, int from, int n_bits);

View File

@@ -13,6 +13,7 @@
#include "GC/Clear.h" #include "GC/Clear.h"
#include "GC/Memory.h" #include "GC/Memory.h"
#include "GC/Access.h" #include "GC/Access.h"
#include "GC/ArgTuples.h"
#include "Math/gf2nlong.h" #include "Math/gf2nlong.h"
@@ -83,7 +84,7 @@ public:
static const T& cast(const T& reg) { return *reinterpret_cast<const T*>(&reg); } static const T& cast(const T& reg) { return *reinterpret_cast<const T*>(&reg); }
static Secret<T> input(party_id_t from, const int128& input, int n_bits = -1); static Secret<T> input(party_id_t from, const int128& input, int n_bits = -1);
static Secret<T> input(party_id_t from, Processor<Secret<T>>& processor, int n_bits = -1); static Secret<T> input(Processor<Secret<T>>& processor, const InputArgs& args);
void random(int n_bits, int128 share); void random(int n_bits, int128 share);
void random_bit(); void random_bit();
static Secret<T> reconstruct(const int128& x, int length); static Secret<T> reconstruct(const int128& x, int length);

View File

@@ -10,9 +10,9 @@ namespace GC
{ {
template<class T> template<class T>
Secret<T> Secret<T>::input(party_id_t from, Processor<Secret<T>>& processor, int n_bits) Secret<T> Secret<T>::input(Processor<Secret<T>>& processor, const InputArgs& args)
{ {
return T::get_input(from, processor, n_bits); return T::get_input(processor, args);
} }
template<class T> template<class T>

View File

@@ -109,7 +109,8 @@ class gfp_
gfp_(const __m128i& x) { *this=x; } gfp_(const __m128i& x) { *this=x; }
gfp_(const int128& x) { *this=x.a; } gfp_(const int128& x) { *this=x.a; }
gfp_(const bigint& x) { to_modp(a, x, ZpD); } gfp_(const bigint& x) { to_modp(a, x, ZpD); }
gfp_(int x) { assign(x); } gfp_(int x) { assign(x); }
gfp_(long x) { assign(x); }
gfp_(const void* buffer) { assign((char*)buffer); } gfp_(const void* buffer) { assign((char*)buffer); }
template<int Y> template<int Y>
gfp_(const gfp_<Y, L>& x); gfp_(const gfp_<Y, L>& x);

15
Processor/FixInput.cpp Normal file
View File

@@ -0,0 +1,15 @@
/*
* FixInput.cpp
*
*/
#include "FixInput.h"
const char* FixInput::NAME = "real number";
void FixInput::read(std::istream& in, const int* params)
{
mpf_class x;
in >> x;
items[0] = x << *params;
}

25
Processor/FixInput.h Normal file
View File

@@ -0,0 +1,25 @@
/*
* FixInput.h
*
*/
#ifndef PROCESSOR_FIXINPUT_H_
#define PROCESSOR_FIXINPUT_H_
#include <iostream>
#include "Math/bigint.h"
class FixInput
{
public:
const static int N_DEST = 1;
const static int N_PARAM = 1;
const static char* NAME;
bigint items[N_DEST];
void read(std::istream& in, const int* params);
};
#endif /* PROCESSOR_FIXINPUT_H_ */

23
Processor/FloatInput.cpp Normal file
View File

@@ -0,0 +1,23 @@
/*
* FloatInput.cpp
*
*/
#include "FloatInput.h"
#include <math.h>
const char* FloatInput::NAME = "real number";
void FloatInput::read(std::istream& in, const int* params)
{
double x;
in >> x;
int exp;
double mant = fabs(frexp(x, &exp));
items[0] = round(mant * (1LL << *params));
items[1] = exp - *params;
items[2] = (x == 0);
items[3] = (x < 0);
}

25
Processor/FloatInput.h Normal file
View File

@@ -0,0 +1,25 @@
/*
* FloatInput.h
*
*/
#ifndef PROCESSOR_FLOATINPUT_H_
#define PROCESSOR_FLOATINPUT_H_
#include "Math/bigint.h"
#include <iostream>
class FloatInput
{
public:
const static int N_DEST = 4;
const static int N_PARAM = 1;
const static char* NAME;
long items[N_DEST];
void read(std::istream& in, const int* params);
};
#endif /* PROCESSOR_FLOATINPUT_H_ */

View File

@@ -31,6 +31,7 @@ protected:
public: public:
int values_input; int values_input;
template<class U>
static void input(SubProcessor<T>& Proc, const vector<int>& args); static void input(SubProcessor<T>& Proc, const vector<int>& args);
InputBase(ArithmeticProcessor* proc); InputBase(ArithmeticProcessor* proc);

View File

@@ -184,35 +184,39 @@ T InputBase<T>::finalize(int player)
} }
template<class T> template<class T>
template<class U>
void InputBase<T>::input(SubProcessor<T>& Proc, void InputBase<T>::input(SubProcessor<T>& Proc,
const vector<int>& args) const vector<int>& args)
{ {
auto& input = Proc.input; auto& input = Proc.input;
for (int i = 0; i < Proc.P.num_players(); i++) for (int i = 0; i < Proc.P.num_players(); i++)
input.reset(i); input.reset(i);
assert(args.size() % 2 == 0); int n_arg_tuple = U::N_DEST + U::N_PARAM + 1;
assert(args.size() % n_arg_tuple == 0);
int n_from_me = 0; int n_from_me = 0;
if (Proc.Proc.opts.interactive and Proc.Proc.thread_num == 0) if (Proc.Proc.opts.interactive and Proc.Proc.thread_num == 0)
{ {
for (size_t i = 1; i < args.size(); i += 2) for (size_t i = n_arg_tuple - 1; i < args.size(); i += n_arg_tuple)
n_from_me += (args[i] == Proc.P.my_num()); n_from_me += (args[i] == Proc.P.my_num());
if (n_from_me > 0) if (n_from_me > 0)
cout << "Please input " << n_from_me << " number(s):" << endl; cout << "Please input " << n_from_me << " " << U::NAME << "(s):" << endl;
} }
for (size_t i = 0; i < args.size(); i += 2) for (size_t i = U::N_DEST; i < args.size(); i += n_arg_tuple)
{ {
int n = args[i + 1]; int n = args[i + U::N_PARAM];
if (n == Proc.P.my_num()) if (n == Proc.P.my_num())
{ {
long x = Proc.Proc.get_input(n_from_me > 0); U tuple = Proc.Proc.template get_input<U>(n_from_me > 0, &args[i]);
input.add_mine(x); for (auto x : tuple.items)
input.add_mine(x);
} }
else else
{ {
input.add_other(n); for (int j = 0; j < U::N_DEST; j++)
input.add_other(n);
} }
} }
@@ -222,9 +226,10 @@ void InputBase<T>::input(SubProcessor<T>& Proc,
input.send_mine(); input.send_mine();
vector<vector<int>> regs(Proc.P.num_players()); vector<vector<int>> regs(Proc.P.num_players());
for (size_t i = 0; i < args.size(); i += 2) for (size_t i = 0; i < args.size(); i += n_arg_tuple)
{ {
regs[args[i + 1]].push_back(args[i]); for (int j = 0; j < U::N_DEST; j++)
regs[args[i + n_arg_tuple - 1]].push_back(args[i + j]);
} }
for (int i = 0; i < Proc.P.num_players(); i++) for (int i = 0; i < Proc.P.num_players(); i++)
input.stop(i, regs[i]); input.stop(i, regs[i]);

View File

@@ -101,6 +101,8 @@ enum
PREP = 0x57, PREP = 0x57,
// Input // Input
INPUT = 0x60, INPUT = 0x60,
INPUTFIX = 0xF0,
INPUTFLOAT = 0xF1,
STARTINPUT = 0x61, STARTINPUT = 0x61,
STOPINPUT = 0x62, STOPINPUT = 0x62,
READSOCKETC = 0x63, READSOCKETC = 0x63,

View File

@@ -2,6 +2,9 @@
#include "Processor/Instruction.h" #include "Processor/Instruction.h"
#include "Processor/Machine.h" #include "Processor/Machine.h"
#include "Processor/Processor.h" #include "Processor/Processor.h"
#include "Processor/IntInput.h"
#include "Processor/FixInput.h"
#include "Processor/FloatInput.h"
#include "Exceptions/Exceptions.h" #include "Exceptions/Exceptions.h"
#include "Tools/time-func.h" #include "Tools/time-func.h"
#include "Tools/parse.h" #include "Tools/parse.h"
@@ -288,6 +291,8 @@ void BaseInstruction::parse_operands(istream& s, int pos)
case GDOTPRODS: case GDOTPRODS:
case INPUT: case INPUT:
case GINPUT: case GINPUT:
case INPUTFIX:
case INPUTFLOAT:
num_var_args = get_int(s); num_var_args = get_int(s);
get_vector(num_var_args, start, s); get_vector(num_var_args, start, s);
break; break;
@@ -987,10 +992,16 @@ inline void Instruction::execute(Processor<sint, sgf2n>& Proc) const
Proc.temp.ans2.output(Proc.private_output, false); Proc.temp.ans2.output(Proc.private_output, false);
break; break;
case INPUT: case INPUT:
sint::Input::input(Proc.Procp, start); sint::Input::template input<IntInput>(Proc.Procp, start);
break; break;
case GINPUT: case GINPUT:
sgf2n::Input::input(Proc.Proc2, start); sgf2n::Input::template input<IntInput>(Proc.Proc2, start);
break;
case INPUTFIX:
sint::Input::template input<FixInput>(Proc.Procp, start);
break;
case INPUTFLOAT:
sint::Input::template input<FloatInput>(Proc.Procp, start);
break; break;
case STARTINPUT: case STARTINPUT:
Proc.Procp.input.start(r[0],n); Proc.Procp.input.start(r[0],n);

14
Processor/IntInput.cpp Normal file
View File

@@ -0,0 +1,14 @@
/*
* IntInput.cpp
*
*/
#include "IntInput.h"
const char* IntInput::NAME = "integer";
void IntInput::read(std::istream& in, const int* params)
{
(void) params;
in >> items[0];
}

23
Processor/IntInput.h Normal file
View File

@@ -0,0 +1,23 @@
/*
* IntInput.h
*
*/
#ifndef PROCESSOR_INTINPUT_H_
#define PROCESSOR_INTINPUT_H_
#include <iostream>
class IntInput
{
public:
const static int N_DEST = 1;
const static int N_PARAM = 0;
const static char* NAME;
long items[N_DEST];
void read(std::istream& in, const int* params);
};
#endif /* PROCESSOR_INTINPUT_H_ */

View File

@@ -4,6 +4,9 @@
*/ */
#include "ProcessorBase.h" #include "ProcessorBase.h"
#include "IntInput.h"
#include "FixInput.h"
#include "FloatInput.h"
#include "Exceptions/Exceptions.h" #include "Exceptions/Exceptions.h"
#include <iostream> #include <iostream>
@@ -23,21 +26,27 @@ void ProcessorBase::open_input_file(int my_num, int thread_num)
open_input_file(input_file); open_input_file(input_file);
} }
long long ProcessorBase::get_input(bool interactive) template<class T>
T ProcessorBase::get_input(bool interactive, const int* params)
{ {
if (interactive) if (interactive)
return get_input(cin, "standard input"); return get_input<T>(cin, "standard input", params);
else else
return get_input(input_file, input_filename); return get_input<T>(input_file, input_filename, params);
} }
long long ProcessorBase::get_input(istream& input_file, const string& input_filename) template<class T>
T ProcessorBase::get_input(istream& input_file, const string& input_filename, const int* params)
{ {
long long res; T res;
input_file >> res; res.read(input_file, params);
if (input_file.eof()) if (input_file.eof())
throw IO_Error("not enough inputs in " + input_filename); throw IO_Error("not enough inputs in " + input_filename);
if (input_file.fail()) if (input_file.fail())
throw IO_Error("cannot read from " + input_filename); throw IO_Error("cannot read from " + input_filename);
return res; return res;
} }
template IntInput ProcessorBase::get_input(bool, const int*);
template FixInput ProcessorBase::get_input(bool, const int*);
template FloatInput ProcessorBase::get_input(bool, const int*);

View File

@@ -40,8 +40,10 @@ public:
void open_input_file(const string& name); void open_input_file(const string& name);
void open_input_file(int my_num, int thread_num); void open_input_file(int my_num, int thread_num);
long long get_input(bool interactive); template<class T>
long long get_input(istream& is, const string& input_filename); T get_input(bool interactive, const int* params);
template<class T>
T get_input(istream& is, const string& input_filename, const int* params);
}; };
#endif /* PROCESSOR_PROCESSORBASE_H_ */ #endif /* PROCESSOR_PROCESSORBASE_H_ */

View File

@@ -102,7 +102,7 @@ data = Matrix(3, 2, sfix)
for i in range(3): for i in range(3):
for j in range(2): for j in range(2):
data[i][j] = sfix.from_sint(sint.get_input_from(j)) data[i][j] = sfix.get_input_from(j)
# compute weighted average # compute weighted average

View File

@@ -119,7 +119,7 @@ void YaoEvalWire::inputb(GC::Processor<GC::Secret<YaoEvalWire> >& processor,
} }
else else
{ {
long long input = processor.get_input(x.n_bits, interactive); long long input = processor.get_input(x.params, interactive);
size_t start = inputs.size(); size_t start = inputs.size();
inputs.resize(start + x.n_bits); inputs.resize(start + x.n_bits);
for (int i = 0; i < x.n_bits; i++) for (int i = 0; i < x.n_bits; i++)

View File

@@ -204,7 +204,7 @@ void YaoGarbleWire::inputb(GC::Processor<GC::Secret<YaoGarbleWire>>& processor,
dest.resize_regs(x.n_bits); dest.resize_regs(x.n_bits);
if (x.from == 0) if (x.from == 0)
{ {
long long input = processor.get_input(x.n_bits, interactive); long long input = processor.get_input(x.params, interactive);
for (auto& reg : dest.get_regs()) for (auto& reg : dest.get_regs())
{ {
reg.public_input(input & 1); reg.public_input(input & 1);