# Mirror of https://github.com/data61/MP-SPDZ.git
# (synced 2026-01-09 21:48:11 -05:00; 190 lines, 6.6 KiB, plaintext)
from Compiler import instructions_base
|
|
import sys
|
|
|
|
program.bit_length = 128

# Degree of parallelism: the first compile-time argument after the program
# name.  program.args holds only the program name followed by its arguments,
# so, unlike sys.argv[2], this stays correct when compile.py options precede
# the program name or when the compiler is driven from another script.
nparallel = int(program.args[1])

# Use nparallel as the global vector size for everything defined below.
instructions_base.set_global_vector_size(nparallel)

# To disable key-dependent pre-processing set PREPROC=False
PREPROC = True
|
|
|
|
class KDPreprocessing(object):
    """ Generate Key-Dependent Pre-processing

    For every one of ``num_calls`` calls and every one of the ``bit_len``
    output bits, one slot of data is stored at position
    ``call * bit_len + bit``:

    * ``pre_squares``: ``t = r^2 * (b + nonsq * (1 - b))``
    * ``pre_bits``:    the random bit ``b``
    * ``pre_keys``:    ``t * keys[bit]``

    The getters return the slots of the call selected by the internal
    counter ``index``; ``gen_next_pre`` advances that counter and ``reset``
    sets it back to zero.
    """

    def __init__(self, num_calls, keys, nonsq, bit_len):
        """
        :param num_calls: number of calls to prepare data for
        :param keys: indexable container, read at positions 0..bit_len-1
        :param nonsq: factor applied to the random square when b = 0
        :param bit_len: number of slots (output bits) per call
        """
        # legendre has a key for each bit
        self.bit_len = bit_len
        self.num_calls = num_calls

        total_slots = num_calls * bit_len
        pre_squares = VectorArray(total_slots, sint, nparallel)
        pre_bits = VectorArray(total_slots, sint, nparallel)
        pre_keys = VectorArray(total_slots, sint, nparallel)
        self.index = MemValue(regint(0))

        # for each call generate pre-processing data
        @for_range(num_calls)
        def _fill_call(call):
            base = call * bit_len
            for bit_pos in range(bit_len):
                # the first returned value is not needed here
                _, ssq = sint.get_random_square()
                b = sint.get_random_bit()
                # b = 1 gives ssq, b = 0 gives nonsq * ssq
                t = ssq * (b + nonsq * (1 - b))

                slot = base + bit_pos
                pre_squares[slot] = t
                pre_bits[slot] = b
                pre_keys[slot] = t * keys[bit_pos]

        # [r^2], [b], [k_i * r^2]
        self.pre_squares = pre_squares
        self.pre_bits = pre_bits
        self.pre_keys = pre_keys

    def _slot(self, where):
        """ Position of entry ``where`` within the current call's data """
        return self.index * self.bit_len + where

    def get_ss_to_ss(self, where):
        """ Return (masked square, bit) for entry ``where`` of the current call """
        slot = self._slot(where)
        return self.pre_squares[slot], self.pre_bits[slot]

    def get_clear_to_ss(self, where):
        """ Return (masked square, bit, masked key) for entry ``where`` of
        the current call """
        slot = self._slot(where)
        return self.pre_squares[slot], self.pre_bits[slot], self.pre_keys[slot]

    def gen_next_pre(self):
        """ Move on to the data of the next call """
        self.index.iadd(1)

    def reset(self):
        """ Go back to the data of the first call """
        self.index.imul(0)
|
|
|
|
class LegPRF(object):
    """ PRF built from Legendre symbols, with one secret key per output bit

    Output bit i is derived from the Legendre symbol of a masked
    ``key[i] + msg``; the bits are then composed into a single value.
    In the method names and docstrings, ``[x]`` / "ss" denotes a secret
    value and "clear" a public one.  The module-level ``PREPROC`` flag
    selects whether the secret-output variants read their masks from the
    key-dependent pre-processing or generate them on the fly.
    """

    class PRFHelper(object):
        """ Helper functions, independent of the PRF key """

        def __init__(self, noutput):
            self.twos = [cint(2)**i for i in range(noutput)]
            self.noutput = noutput

        def bit_compose(self, prf_out):
            """ Return the sum of prf_out[i] * 2^i over the noutput bits """
            return sum(self.twos[i] * prf_out[i]
                       for i in range(self.noutput))

    def __init__(self, num_calls, bit_len):
        """
        :param num_calls: number of calls the key-dependent pre-processing
            is generated for
        :param bit_len: number of output bits (and of key elements)
        """
        self.prf_bit_len = bit_len
        self.prf_helper = LegPRF.PRFHelper(bit_len)
        # quadratic non-residue modulo p (\alpha in paper)
        self.precomp = cint(3482426629118966072265963355631718388)

        self.key = Array(bit_len, sint)

        @for_range(bit_len)
        def _draw_key(i):
            # first component of a random triple serves as key element
            self.key[i] = sint.get_random_triple()[0]

        # sanity check; the key was allocated with bit_len entries above
        if len(self.key) != self.prf_bit_len:
            raise CompilerError('Invalid key length to Legendre PRF')

        self.kd_pre = KDPreprocessing(num_calls, self.key, self.precomp,
                                      self.prf_bit_len)

    # -- shared building blocks ---------------------------------------

    @staticmethod
    def _random_mask(r_nonsq):
        """ Fresh (t, b) with t = ssq if b = 1 and t = r_nonsq * ssq if b = 0 """
        _, ssq = sint.get_random_square()
        b = sint.get_random_bit()
        nonsq = r_nonsq*ssq
        t = b*(ssq - nonsq) + nonsq
        return t, b

    @staticmethod
    def _unmask_bit(v, b):
        """ Map the Legendre symbol of v, sign-corrected by bit b, to 0/1 """
        return (v.legendre() * (2*b - 1) + 1) / 2

    @staticmethod
    def _plain_bit(v):
        """ Map the Legendre symbol of v to 0/1 """
        return (v.legendre() + 1) / 2

    def reset_kd_pre(self):
        """ Rewind the key-dependent pre-processing to its first call """
        self.kd_pre.reset()

    # -- [m] -> [c] ----------------------------------------------------

    def encrypt_ss_to_ss(self, m):
        """ Takes as input [m], returns [c] """
        if PREPROC:
            prf_out = self.pre_prf_legendre_ss_to_ss(m, self.precomp)
            # this call's pre-processed data is used up
            self.kd_pre.gen_next_pre()
        else:
            prf_out = self.prf_legendre_ss_to_ss(m, self.precomp)
        return self.prf_helper.bit_compose(prf_out)

    def pre_prf_legendre_ss_to_ss(self, msg, r_nonsq):
        """ Variant with Key-Dependent Pre-processing

        ``r_nonsq`` is not used here; it only mirrors the signature of the
        variant without pre-processing.
        """
        res = []
        for i in range(self.prf_bit_len):
            sq, bit = self.kd_pre.get_ss_to_ss(where=i)
            v = (sq * (self.key[i] + msg)).reveal()
            res.append(self._unmask_bit(v, bit))
        return res

    def prf_legendre_ss_to_ss(self, msg, r_nonsq):
        """ Variant without Key-Dependent Pre-processing """
        res = []
        for i in range(self.prf_bit_len):
            t, b = self._random_mask(r_nonsq)

            v = (t * (self.key[i] + msg)).reveal()
            res.append(self._unmask_bit(v, b))
        return res

    # -- m -> [c] ------------------------------------------------------

    def encrypt_clear_to_ss(self, m):
        """ Takes as input m, returns [c] """
        if PREPROC:
            prf_out = self.pre_prf_legendre_clear_to_ss(m, self.precomp)
            # this call's pre-processed data is used up
            self.kd_pre.gen_next_pre()
        else:
            prf_out = self.prf_legendre_clear_to_ss(m, self.precomp)
        return self.prf_helper.bit_compose(prf_out)

    def pre_prf_legendre_clear_to_ss(self, msg, r_nonsq):
        """ Variant with Key-Dependent Pre-processing

        ``r_nonsq`` is not used here; it only mirrors the signature of the
        variant without pre-processing.
        """
        res = []
        for i in range(self.prf_bit_len):
            sq, bit, prk = self.kd_pre.get_clear_to_ss(where=i)
            # prk already contains the key multiplied by the mask
            v = (prk + sq * msg).reveal()
            res.append(self._unmask_bit(v, bit))
        return res

    def prf_legendre_clear_to_ss(self, msg, r_nonsq):
        """ Variant without Key-Dependent Pre-processing """
        res = []
        for i in range(self.prf_bit_len):
            t, b = self._random_mask(r_nonsq)
            v = (t * self.key[i] + t * msg).reveal()
            res.append(self._unmask_bit(v, b))
        return res

    # -- [m] -> c ------------------------------------------------------

    def encrypt_ss_to_clear(self, m):
        """ Takes as input [m], returns c """
        prf_out = self.prf_legendre_ss_to_clear(m)
        return self.prf_helper.bit_compose(prf_out)

    def pre_prf_legendre_ss_to_clear(self, msg):
        """ Same as prf_legendre_ss_to_clear; no pre-processed data is used """
        return self.prf_legendre_ss_to_clear(msg)

    def prf_legendre_ss_to_clear(self, msg):
        """ Variant with / without Key-Dependent Pre-processing """
        res = []
        for i in range(self.prf_bit_len):
            _, ssq = sint.get_random_square()
            v = (ssq * (self.key[i] + msg)).reveal()
            res.append(self._plain_bit(v))
        return res

    # -- m -> c --------------------------------------------------------

    def encrypt_clear_to_clear(self, m):
        """ Takes as input m, returns c """
        prf_out = self.prf_legendre_clear_to_clear(m)
        return self.prf_helper.bit_compose(prf_out)

    def prf_legendre_clear_to_clear(self, msg):
        """ Variant with / without Key-Dependent Pre-processing """
        res = []
        for i in range(self.prf_bit_len):
            _, ssq = sint.get_random_square()
            v = (ssq * self.key[i] + ssq * msg).reveal()
            res.append(self._plain_bit(v))
        return res
|