Files
MP-SPDZ/Programs/Source/prf_leg.mpc
2019-05-23 14:35:35 +02:00

190 lines
6.6 KiB
Plaintext

from Compiler import instructions_base
import sys
# Set the bit_length attribute of the compiler-provided `program` object.
program.bit_length = 128
# Read from the third command-line argument (sys.argv[2]); used below both as
# the global vector size and as the vector length of the pre-processing tables.
nparallel = int(sys.argv[2])
# Hand nparallel to the instruction layer as the global vector size
# (presumably every emitted instruction then acts on nparallel values at
# once -- TODO confirm against Compiler.instructions_base).
instructions_base.set_global_vector_size(nparallel)
# Selects the code path of LegPRF.encrypt_ss_to_ss / encrypt_clear_to_ss:
# True reads the tables built by KDPreprocessing, False draws fresh random
# squares and bits on every call.
PREPROC = True
# To disable key-dependent pre-processing set PREPROC=False
class KDPreprocessing(object):
    """ Generate Key-Dependent Pre-processing

    For every one of ``num_calls`` PRF calls and every one of the ``bit_len``
    output bits, three secret values are stored at row
    ``call * bit_len + bit``:

    * ``pre_squares``: ``t = ssq * (b + nonsq * (1 - b))`` where ``ssq`` is
      the second value returned by ``sint.get_random_square()``; i.e. ``ssq``
      when ``b`` is 1 and ``nonsq * ssq`` when ``b`` is 0,
    * ``pre_bits``: the random bit ``b``,
    * ``pre_keys``: ``t * keys[bit]``.

    ``self.index`` is the call whose rows are currently served by the
    getters; it is advanced with ``gen_next_pre`` and zeroed with ``reset``.
    No check is made that ``index`` stays below ``num_calls``.
    """

    def __init__(self, num_calls, keys, nonsq, bit_len):
        """ Allocate the three tables and fill them for all calls

        num_calls -- number of PRF calls to prepare data for
        keys      -- indexable by 0 .. bit_len-1, one key per output bit
        nonsq     -- factor applied to the random square when the bit is 0
        bit_len   -- number of output bits per PRF call
        """
        # legendre has a key for each bit
        self.bit_len = bit_len
        total_rows = num_calls * bit_len
        pre_squares = VectorArray(total_rows, sint, nparallel)
        pre_bits = VectorArray(total_rows, sint, nparallel)
        pre_keys = VectorArray(total_rows, sint, nparallel)
        self.num_calls = num_calls
        self.index = MemValue(regint(0))

        # for each call generate pre-processing data
        @for_range(num_calls)
        def block(row_index):
            cur_row = row_index * bit_len
            # the per-bit loop is a plain Python loop, so keys[i] is
            # indexed with a Python integer
            for i in range(bit_len):
                # only the square is needed, not the first returned value
                _, ssq = sint.get_random_square()
                b = sint.get_random_bit()
                t = ssq * (b + nonsq * (1 - b))
                next_row = cur_row + i
                pre_squares[next_row] = t
                pre_bits[next_row] = b
                pre_keys[next_row] = t * keys[i]

        # [r^2], [b], [k_i * r^2]
        self.pre_squares = pre_squares
        self.pre_bits = pre_bits
        self.pre_keys = pre_keys

    def _row(self, where):
        """ Row of output bit `where` within the call selected by index """
        return self.index * self.bit_len + where

    def get_ss_to_ss(self, where):
        """ Return (masked square, mask bit) for output bit `where` """
        target = self._row(where)
        return self.pre_squares[target], self.pre_bits[target]

    def get_clear_to_ss(self, where):
        """ Return (masked square, mask bit, masked square * key) for
        output bit `where` """
        target = self._row(where)
        return (self.pre_squares[target], self.pre_bits[target],
                self.pre_keys[target])

    def gen_next_pre(self):
        """ Move on to the rows of the next PRF call """
        self.index.iadd(1)

    def reset(self):
        """ Go back to the rows of the first PRF call """
        # multiplying the stored index by zero sets it back to 0
        self.index.imul(0)
class LegPRF(object):
    """ Legendre PRF with one key element per output bit

    The four ``encrypt_*`` methods cover the combinations of secret-shared
    ("ss") or clear input and output.  Each evaluates ``prf_bit_len`` output
    bits and packs them into one value with ``PRFHelper.bit_compose``.

    For secret-shared output, ``key[i] + msg`` is multiplied by a mask ``t``
    before being revealed; ``t`` is a random square when the mask bit is 1
    and the non-residue times a random square when it is 0.  The mask bit is
    then used to undo the effect on the value returned by ``legendre()``.
    With the module flag ``PREPROC`` set, the masks come from
    ``KDPreprocessing`` instead of being drawn on the fly.
    """

    class PRFHelper(object):
        """ Helper functions, independent of the PRF key """

        def __init__(self, noutput):
            """ Precompute the powers 2**0 .. 2**(noutput-1) as cint """
            self.twos = [cint(2)**i for i in range(noutput)]
            self.noutput = noutput

        def bit_compose(self, prf_out):
            """ Return sum of prf_out[i] * 2**i over the noutput bits """
            return sum(self.twos[i] * prf_out[i]
                       for i in range(self.noutput))

    def __init__(self, num_calls, bit_len):
        """ Sample a key and build the key-dependent pre-processing

        num_calls -- number of PRF calls the pre-processing is sized for
        bit_len   -- number of output bits, and of key elements
        """
        self.prf_bit_len = bit_len
        self.prf_helper = LegPRF.PRFHelper(bit_len)
        # quadratic non-residue modulo p (\alpha in paper)
        self.precomp = cint(3482426629118966072265963355631718388)
        self.key = Array(bit_len, sint)

        @for_range(bit_len)
        def block(i):
            # first element of a random triple serves as a random key element
            self.key[i] = sint.get_random_triple()[0]

        # defensive: self.key was allocated with bit_len entries just above
        if len(self.key) != self.prf_bit_len:
            raise CompilerError('Invalid key length to Legendre PRF')
        # the tables are built regardless of the PREPROC flag
        self.kd_pre = KDPreprocessing(num_calls, self.key, self.precomp,
                                      self.prf_bit_len)

    # -- private helpers shared by the variants below ----------------------

    @staticmethod
    def _fresh_mask(r_nonsq):
        """ Draw (t, b): b is a random bit, t is a random square if b is 1
        and r_nonsq times that square if b is 0 """
        _, ssq = sint.get_random_square()
        b = sint.get_random_bit()
        nonsq = r_nonsq*ssq
        t = b*(ssq - nonsq) + nonsq
        return t, b

    @staticmethod
    def _unmask_bit(symbol, bit):
        """ Turn the legendre() result of a masked value into an output bit,
        flipping its sign when the mask bit is 0 """
        return (symbol * (2*bit - 1) + 1) / 2

    @staticmethod
    def _symbol_to_bit(symbol):
        """ Map a legendre() result s to (s + 1) / 2 """
        return (symbol + 1) / 2

    # -- pre-processing control --------------------------------------------

    def reset_kd_pre(self):
        """ Restart consumption of the pre-processed data from call 0 """
        self.kd_pre.reset()

    # -- secret-shared input, secret-shared output -------------------------

    def encrypt_ss_to_ss(self, m):
        """ Takes as input [m], returns [c] """
        if PREPROC:
            prf_out = self.pre_prf_legendre_ss_to_ss(m, self.precomp)
            # this call's rows are used up
            self.kd_pre.gen_next_pre()
        else:
            prf_out = self.prf_legendre_ss_to_ss(m, self.precomp)
        return self.prf_helper.bit_compose(prf_out)

    def pre_prf_legendre_ss_to_ss(self, msg, r_nonsq):
        """ Variant with Key-Dependent Pre-processing

        r_nonsq is not used here (the stored masks already include the
        non-residue); it is kept so both variants share a signature.
        """
        bits = []
        for i in range(self.prf_bit_len):
            sq, bit = self.kd_pre.get_ss_to_ss(where=i)
            v = (sq * (self.key[i] + msg)).reveal()
            bits.append(self._unmask_bit(v.legendre(), bit))
        return bits

    def prf_legendre_ss_to_ss(self, msg, r_nonsq):
        """ Variant without Key-Dependent Pre-processing """
        bits = []
        for i in range(self.prf_bit_len):
            t, b = self._fresh_mask(r_nonsq)
            v = (t * (self.key[i] + msg)).reveal()
            bits.append(self._unmask_bit(v.legendre(), b))
        return bits

    # -- clear input, secret-shared output ---------------------------------

    def encrypt_clear_to_ss(self, m):
        """ Takes as input m, returns [c] """
        if PREPROC:
            prf_out = self.pre_prf_legendre_clear_to_ss(m, self.precomp)
            # this call's rows are used up
            self.kd_pre.gen_next_pre()
        else:
            prf_out = self.prf_legendre_clear_to_ss(m, self.precomp)
        return self.prf_helper.bit_compose(prf_out)

    def pre_prf_legendre_clear_to_ss(self, msg, r_nonsq):
        """ Variant with Key-Dependent Pre-processing

        Uses the stored product mask * key, so only mask * msg is computed
        here.  r_nonsq is not used; it is kept so both variants share a
        signature.
        """
        bits = []
        for i in range(self.prf_bit_len):
            sq, bit, prk = self.kd_pre.get_clear_to_ss(where=i)
            v = (prk + sq * msg).reveal()
            bits.append(self._unmask_bit(v.legendre(), bit))
        return bits

    def prf_legendre_clear_to_ss(self, msg, r_nonsq):
        """ Variant without Key-Dependent Pre-processing """
        bits = []
        for i in range(self.prf_bit_len):
            t, b = self._fresh_mask(r_nonsq)
            v = (t * self.key[i] + t * msg).reveal()
            bits.append(self._unmask_bit(v.legendre(), b))
        return bits

    # -- secret-shared input, clear output ---------------------------------

    def encrypt_ss_to_clear(self, m):
        """ Takes as input [m], returns c """
        prf_out = self.prf_legendre_ss_to_clear(m)
        return self.prf_helper.bit_compose(prf_out)

    def pre_prf_legendre_ss_to_clear(self, msg):
        """ Same as prf_legendre_ss_to_clear; no stored data is used """
        return self.prf_legendre_ss_to_clear(msg)

    def prf_legendre_ss_to_clear(self, msg):
        """ Variant with / without Key-Dependent Pre-processing

        A fresh random square masks key[i] + msg before it is revealed; no
        mask bit is involved, so the output bits are clear values.
        """
        bits = []
        for i in range(self.prf_bit_len):
            _, ssq = sint.get_random_square()
            v = (ssq * (self.key[i] + msg)).reveal()
            bits.append(self._symbol_to_bit(v.legendre()))
        return bits

    # -- clear input, clear output -----------------------------------------

    def encrypt_clear_to_clear(self, m):
        """ Takes as input m, returns c """
        prf_out = self.prf_legendre_clear_to_clear(m)
        return self.prf_helper.bit_compose(prf_out)

    def prf_legendre_clear_to_clear(self, msg):
        """ Variant with / without Key-Dependent Pre-processing

        Like prf_legendre_ss_to_clear, but the mask is multiplied with the
        key and with the message separately.
        """
        bits = []
        for i in range(self.prf_bit_len):
            _, ssq = sint.get_random_square()
            v = (ssq * self.key[i] + ssq * msg).reveal()
            bits.append(self._symbol_to_bit(v.legendre()))
        return bits