From b17d4b62544b2f95b29bc619d4ebeea7c98c4d06 Mon Sep 17 00:00:00 2001 From: Marcel Keller Date: Tue, 11 Feb 2020 21:11:40 +1100 Subject: [PATCH] Input by run-time player id. --- Compiler/instructions.py | 27 ++++++++++++++++++--------- Compiler/instructions_base.py | 1 + Compiler/types.py | 7 +++++++ Processor/Input.h | 4 +++- Processor/Input.hpp | 23 ++++++++++++++++++++--- Processor/Instruction.h | 1 + Processor/Instruction.hpp | 6 +++++- Processor/Processor.h | 19 ++++++++++--------- 8 files changed, 65 insertions(+), 23 deletions(-) diff --git a/Compiler/instructions.py b/Compiler/instructions.py index 23db3fc1..9fda1346 100644 --- a/Compiler/instructions.py +++ b/Compiler/instructions.py @@ -918,10 +918,8 @@ class inputfloat(base.TextInputInstruction): req_node.increment((self.field_type, 'input', player), \ 4 * self.get_size()) -@base.vectorize -class inputmixed(base.TextInputInstruction): +class inputmixed_base(base.TextInputInstruction): __slots__ = [] - code = base.opcodes['INPUTMIXED'] field_type = 'modp' # the following has to match TYPE: (N_DEST, N_PARAM) types = { @@ -936,11 +934,8 @@ class inputmixed(base.TextInputInstruction): } def __init__(self, name, *args): - try: - type_id = self.type_ids[name] - except: - pass - super(inputmixed_class, self).__init__(type_id, *args) + type_id = self.type_ids[name] + super(inputmixed_base, self).__init__(type_id, *args) @property def arg_format(self): @@ -951,7 +946,7 @@ class inputmixed(base.TextInputInstruction): yield 'sw' for j in range(self.types[t][1]): yield 'int' - yield 'p' + yield self.player_arg_type def bases(self): i = 0 @@ -959,6 +954,11 @@ class inputmixed(base.TextInputInstruction): yield i i += sum(self.types[self.args[i]]) + 2 +@base.vectorize +class inputmixed(inputmixed_base): + code = base.opcodes['INPUTMIXED'] + player_arg_type = 'p' + def add_usage(self, req_node): for i in self.bases(): t = self.args[i] @@ -967,6 +967,15 @@ class inputmixed(base.TextInputInstruction): req_node.increment((self.field_type, 'input', player), \ n_dest * self.get_size()) +@base.vectorize +class inputmixedreg(inputmixed_base): + code = base.opcodes['INPUTMIXEDREG'] + player_arg_type = 'ci' + + def add_usage(self, req_node): + # player 0 as proxy + req_node.increment((self.field_type, 'input', 0), float('inf')) + @base.gf2n class startinput(base.RawInputInstruction): r""" Receive inputs from player $p$. """ diff --git a/Compiler/instructions_base.py b/Compiler/instructions_base.py index 89dff9ca..8508bab5 100644 --- a/Compiler/instructions_base.py +++ b/Compiler/instructions_base.py @@ -108,6 +108,7 @@ opcodes = dict( INPUTFIX = 0xF0, INPUTFLOAT = 0xF1, INPUTMIXED = 0xF2, + INPUTMIXEDREG = 0xF3, STARTINPUT = 0x61, STOPINPUT = 0x62, READSOCKETC = 0x63, diff --git a/Compiler/types.py b/Compiler/types.py index 13d6c110..ea8126a4 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -74,6 +74,7 @@ from .floatingpoint import two_power from . import comparison, floatingpoint import math from . import util +from . import instructions from .util import is_zero, is_one import operator from functools import reduce @@ -208,6 +209,12 @@ def read_mem_value(operation): copy_doc(read_mem_operation, operation) return read_mem_operation +def inputmixed(*args): + # helper to cover both cases + if isinstance(args[-1], int): + instructions.inputmixed(*args) + else: + instructions.inputmixedreg(*args) class _number(object): """ Number functionality. """ diff --git a/Processor/Input.h b/Processor/Input.h index ecc9b1f4..42593499 100644 --- a/Processor/Input.h +++ b/Processor/Input.h @@ -34,7 +34,9 @@ public: template static void input(SubProcessor& Proc, const vector& args, int size); - static void input_mixed(SubProcessor& Proc, const vector& args, int size); + static int get_player(SubProcessor& Proc, int arg, bool player_from_arg); + static void input_mixed(SubProcessor& Proc, const vector& args, + int size, bool player_from_reg); template static void prepare(SubProcessor& Proc, int player, const int* params, int size); template diff --git a/Processor/Input.hpp b/Processor/Input.hpp index 51f95930..8b741fc1 100644 --- a/Processor/Input.hpp +++ b/Processor/Input.hpp @@ -274,9 +274,24 @@ void InputBase::input(SubProcessor& Proc, } } +template +int InputBase::get_player(SubProcessor& Proc, int arg, bool player_from_reg) +{ + if (player_from_reg) + { + assert(Proc.Proc); + auto res = Proc.Proc->read_Ci(arg); + if (res >= Proc.P.num_players()) + throw runtime_error("player id too large: " + to_string(res)); + return res; + } + else + return arg; +} + template void InputBase::input_mixed(SubProcessor& Proc, const vector& args, - int size) + int size, bool player_from_reg) { auto& input = Proc.input; input.reset_all(Proc.P); @@ -293,7 +308,7 @@ void InputBase::input_mixed(SubProcessor& Proc, const vector& args, #define X(U) \ case U::TYPE: \ n_arg_tuple = U::N_DEST + U::N_PARAM + 2; \ - player = args[i + n_arg_tuple - 1]; \ + player = get_player(Proc, args[i + n_arg_tuple - 1], player_from_reg); \ if (type != last_type and Proc.Proc and Proc.Proc->use_stdin()) \ cout << "Please input " << U::NAME << "s:" << endl; \ prepare(Proc, player, &args[i + U::N_DEST + 1], size); \ @@ -313,12 +328,14 @@ void InputBase::input_mixed(SubProcessor& Proc, const vector& args, { int n_arg_tuple; int type = args[i]; + int player; switch (type) { #define X(U) \ case U::TYPE: \ n_arg_tuple = U::N_DEST + U::N_PARAM + 2; \ - finalize(Proc, args[i + n_arg_tuple - 1], &args[i + 1], size); \ + player = get_player(Proc, args[i + n_arg_tuple - 1], player_from_reg); \ + finalize(Proc, player, &args[i + 1], size); \ break; X(IntInput) X(FixInput) X(FloatInput) #undef X diff --git a/Processor/Instruction.h b/Processor/Instruction.h index aff1b82a..ee43c2af 100644 --- a/Processor/Instruction.h +++ b/Processor/Instruction.h @@ -109,6 +109,7 @@ enum INPUTFIX = 0xF0, INPUTFLOAT = 0xF1, INPUTMIXED = 0xF2, + INPUTMIXEDREG = 0xF3, STARTINPUT = 0x61, STOPINPUT = 0x62, READSOCKETC = 0x63, diff --git a/Processor/Instruction.hpp b/Processor/Instruction.hpp index b339681f..f735fa8a 100644 --- a/Processor/Instruction.hpp +++ b/Processor/Instruction.hpp @@ -298,6 +298,7 @@ void BaseInstruction::parse_operands(istream& s, int pos, int file_pos) case INPUTFIX: case INPUTFLOAT: case INPUTMIXED: + case INPUTMIXEDREG: case TRUNC_PR: num_var_args = get_int(s); get_vector(num_var_args, start, s); @@ -1145,7 +1146,10 @@ inline void Instruction::execute(Processor& Proc) const sint::Input::template input(Proc.Procp, start, size); return; case INPUTMIXED: - sint::Input::input_mixed(Proc.Procp, start, size); + sint::Input::input_mixed(Proc.Procp, start, size, false); + return; + case INPUTMIXEDREG: + sint::Input::input_mixed(Proc.Procp, start, size, true); return; case STARTINPUT: Proc.Procp.input.start(r[0],n); diff --git a/Processor/Processor.h b/Processor/Processor.h index 3915d1ac..73a5415f 100644 --- a/Processor/Processor.h +++ b/Processor/Processor.h @@ -80,6 +80,9 @@ public: class ArithmeticProcessor : public ProcessorBase { +protected: + vector Ci; + public: int thread_num; @@ -103,13 +106,18 @@ public: { return thread_num == 0 and opts.interactive; } + + const long& read_Ci(int i) const + { return Ci[i]; } + long& get_Ci_ref(int i) + { return Ci[i]; } + void write_Ci(int i,const long& x) + { Ci[i]=x; } }; template class Processor : public ArithmeticProcessor { - vector Ci; - int reg_max2,reg_maxp,reg_maxi; // Data structure used for reading/writing data to/from a socket (i.e. an external party to SPDZ) @@ -184,13 +192,6 @@ class Processor : public ArithmeticProcessor void write_Sp(int i,const sint & x) { Procp.S[i]=x; } - const long& read_Ci(int i) const - { return Ci[i]; } - long& get_Ci_ref(int i) - { return Ci[i]; } - void write_Ci(int i,const long& x) - { Ci[i]=x; } - void dabit(const Instruction& instruction); // Access to external client sockets for reading clear/shared data