diff --git a/Compiler/instructions.py b/Compiler/instructions.py index 9a2e88a2..63ad1173 100644 --- a/Compiler/instructions.py +++ b/Compiler/instructions.py @@ -913,7 +913,7 @@ class print_reg_plain(base.IOInstruction): class print_float_plain(base.IOInstruction): __slots__ = [] code = base.opcodes['PRINTFLOATPLAIN'] - arg_format = ['c', 'c', 'c'] + arg_format = ['c', 'c', 'c', 'c'] class print_char(base.IOInstruction): diff --git a/Compiler/library.py b/Compiler/library.py index 319f54d1..d32a7865 100644 --- a/Compiler/library.py +++ b/Compiler/library.py @@ -108,13 +108,7 @@ def print_str(s, *args): elif isinstance(val, sfix) or isinstance(val, sfloat): raise CompilerError('Cannot print secret value:', args[i]) elif isinstance(val, cfloat): - # Since we have only three registers, we separate the 0 case with the others - @if_e (val.z == 1) - def _(): - cint(0).print_reg_plain() - @else_ - def _(): - val.print_float_plain() + val.print_float_plain() elif isinstance(val, list): print_str('[' + ', '.join('%s' for i in range(len(val))) + ']', *val) else: diff --git a/Compiler/types.py b/Compiler/types.py index 3bd8ef78..8c0fa192 100644 --- a/Compiler/types.py +++ b/Compiler/types.py @@ -2037,21 +2037,12 @@ class sfloat(_number): class cfloat(object): # Helper class used for printing sfloats __slots__ = ['v', 'p', 'z', 's'] - instruction_type = 'modp' - size = 1 def __init__(self, v, p, z, s): - if not isinstance(v, cint) or not isinstance(p, cint) or not isinstance(z,cint) or not isinstance(s, cint): - raise CompilerError("Cfloat construction requires cints") - self.v = v - self.p = p - self.z = z - self.s = s + self.v, self.p, self.z, self.s = [cint.conv(x) for x in (v, p, z, s)] - @set_instruction_type - @vectorize def print_float_plain(self): - print_float_plain(self.v, self.p, self.s) + print_float_plain(self.v, self.p, self.z, self.s) _types = { 'c': cint, diff --git a/Math/gfp.h b/Math/gfp.h index f35ebb8f..5ca3afb5 100644 --- a/Math/gfp.h +++ b/Math/gfp.h @@ -93,7 +93,8 @@ class gfp bool is_zero() const { return isZero(a,ZpD); } - bool is_one() const { return isOne(a,ZpD); } + bool is_one() const { return isOne(a,ZpD); } + bool is_bit() const { return is_zero() or is_one(); } bool equal(const gfp& y) const { return areEqual(a,y.a,ZpD); } bool operator==(const gfp& y) const { return equal(y); } bool operator!=(const gfp& y) const { return !equal(y); } diff --git a/Processor/Instruction.cpp b/Processor/Instruction.cpp index d18243f9..8f104fa6 100644 --- a/Processor/Instruction.cpp +++ b/Processor/Instruction.cpp @@ -8,7 +8,6 @@ #include "Tools/time-func.h" #include -#include #include #include #include @@ -36,20 +35,6 @@ int get_int(istream& s) return n; } -long gfpToLong(bigint& tmp, gfp& g){ - to_bigint(tmp, g); - long ret = tmp.get_si(); - if (ret < 0){ - //Dirty trick to allow conversion - g.negate(); - to_bigint(tmp, g); - ret = tmp.get_si(); - ret *= -1; - g.negate(); - } - return ret; - } - // Convert modp to signed bigint of a given bit length void to_signed_bigint(bigint& bi, const gfp& x, int len) { @@ -138,7 +123,6 @@ void Instruction::parse(istream& s) case SUBINT: case MULINT: case DIVINT: - case PRINTFLOATPLAIN: r[0]=get_int(s); r[1]=get_int(s); r[2]=get_int(s); @@ -296,6 +280,10 @@ void Instruction::parse(istream& s) case CRASH: case CLOSESOCKET: break; + // instructions with 4 register operands + case PRINTFLOATPLAIN: + get_vector(4, start, s); + break; // open instructions case STARTOPEN: case STOPOPEN: @@ -1444,17 +1432,28 @@ void Instruction::execute(Processor& Proc) const break; case PRINTFLOATPLAIN: if (Proc.P.my_num() == 0) - { - gfp v = Proc.read_Cp(r[0]); - gfp p = Proc.read_Cp(r[1]); - gfp s = Proc.read_Cp(r[2]); - long lv = gfpToLong(Proc.temp.aa, v); - long lp = gfpToLong(Proc.temp.aa2, p); - double res = (double)lv * pow(2.0,lp); - if(!s.is_zero()) - res *= -1; - cout << res < 0) + mpf_mul_2exp(res.get_mpf_t(), res.get_mpf_t(), exp); + else + mpf_div_2exp(res.get_mpf_t(), res.get_mpf_t(), -exp); + if (z.is_one()) + res = 0; + if (!s.is_zero()) + res *= -1; + if (not z.is_bit() or not s.is_bit()) + throw Processor_Error("invalid floating point number"); + cout << res << flush; + } break; case PRINTSTR: if (Proc.P.my_num() == 0)