from Compiler import instructions_base
import sys

program.bit_length = 128
nparallel = int(sys.argv[2])

instructions_base.set_global_vector_size(nparallel)

# To disable key-dependent pre-processing set PREPROC=False
PREPROC = True


class KDPreprocessing(object):
    """ Generate Key-Dependent Pre-processing

    For every (call, key bit) pair three values are stored:
      pre_squares -- mask  = ssq * (b + nonsq * (1 - b)),
                     i.e. ssq when b = 1 and nonsq * ssq when b = 0
      pre_bits    -- the random bit b that selected the mask
      pre_keys    -- mask * keys[i]
    Entries for one call are contiguous: slot = call * bit_len + i.
    """

    def __init__(self, num_calls, keys, nonsq, bit_len):
        # legendre has a key for each bit
        self.bit_len = bit_len
        slots = num_calls * bit_len

        # Allocation order (squares, bits, keys, then the counter) is kept
        # as-is because it fixes the memory layout of the compiled program.
        squares = VectorArray(slots, sint, nparallel)
        bits = VectorArray(slots, sint, nparallel)
        masked_keys = VectorArray(slots, sint, nparallel)
        self.pre_squares = squares
        self.pre_bits = bits
        self.pre_keys = masked_keys

        self.num_calls = num_calls
        # run-time counter selecting which call's data is handed out next
        self.index = MemValue(regint(0))

        # for each call generate pre-processing data: [r^2], [b], [k_i * r^2]
        @for_range(num_calls)
        def block(call):
            base = call * bit_len
            for offset in range(bit_len):
                _, sq = sint.get_random_square()
                bit = sint.get_random_bit()
                mask = sq * (bit + nonsq * (1 - bit))
                slot = base + offset
                squares[slot] = mask
                bits[slot] = bit
                masked_keys[slot] = mask * keys[offset]

    def _slot(self, where):
        """ Position of entry `where` inside the currently selected call """
        return self.index * self.bit_len + where

    def get_ss_to_ss(self, where):
        """ Return (mask, bit) for key position `where` of the current call """
        slot = self._slot(where)
        return self.pre_squares[slot], self.pre_bits[slot]

    def get_clear_to_ss(self, where):
        """ Return (mask, bit, mask * key) for key position `where` """
        slot = self._slot(where)
        return self.pre_squares[slot], self.pre_bits[slot], self.pre_keys[slot]

    def gen_next_pre(self):
        """ Advance to the next call's data (no bound check against num_calls) """
        self.index.iadd(1)

    def reset(self):
        """ Rewind the call counter to zero """
        self.index.imul(0)


class LegPRF(object):
    """ Legendre-symbol based PRF producing `bit_len` output bits.

    The encrypt_* methods are named <input>_to_<output>; following their
    docstrings, "ss" is a shared value ([m] / [c]) and "clear" an opened one.
    """

    class PRFHelper(object):
        """ Helper functions, independent of the PRF key """

        def __init__(self, noutput):
            self.noutput = noutput
            self.twos = []
            for exponent in range(noutput):
                self.twos.append(cint(2)**exponent)

        def bit_compose(self, prf_out):
            """ Return sum_i 2^i * prf_out[i] over the first noutput entries """
            total = 0
            for i in range(self.noutput):
                total = total + self.twos[i] * prf_out[i]
            return total

        def symbol_to_bit(self, symbol):
            """ Compute (symbol + 1) / 2, which sends -1 to 0 and 1 to 1 """
            return (symbol + 1) / 2

    def __init__(self, num_calls, bit_len):
        self.prf_bit_len = bit_len
        self.prf_helper = LegPRF.PRFHelper(bit_len)
        # quadratic non-residue modulo p (\alpha in paper)
        self.precomp = cint(3482426629118966072265963355631718388)

        # NOTE(review): key is a plain Array of bit_len entries while the
        # global vector size is nparallel (the pre-processing tables use
        # VectorArray) -- confirm vectorised loads/stores stay in bounds.
        self.key = Array(bit_len, sint)

        @for_range(bit_len)
        def block(i):
            self.key[i] = sint.get_random_triple()[0]

        if len(self.key) != self.prf_bit_len:
            raise CompilerError('Invalid key length to Legendre PRF')

        self.kd_pre = KDPreprocessing(num_calls, self.key, self.precomp,
                                      self.prf_bit_len)

    def reset_kd_pre(self):
        """ Rewind the key-dependent pre-processing to its first call """
        self.kd_pre.reset()

    def _fresh_mask(self, r_nonsq):
        """ Draw an on-the-fly mask; returns (mask, bit).

        mask = bit * (ssq - r_nonsq * ssq) + r_nonsq * ssq, i.e. ssq when
        bit = 1 and r_nonsq * ssq when bit = 0.
        """
        _, sq = sint.get_random_square()
        bit = sint.get_random_bit()
        nonsq = r_nonsq*sq
        return bit*(sq - nonsq) + nonsq, bit

    def encrypt_ss_to_ss(self, m):
        """ Takes as input [m], returns [c] """
        if not PREPROC:
            return self.prf_helper.bit_compose(
                self.prf_legendre_ss_to_ss(m, self.precomp))
        out_bits = self.pre_prf_legendre_ss_to_ss(m, self.precomp)
        # one call's worth of pre-processing has been consumed
        self.kd_pre.gen_next_pre()
        return self.prf_helper.bit_compose(out_bits)

    def pre_prf_legendre_ss_to_ss(self, msg, r_nonsq):
        """ Variant with Key-Dependent Pre-processing (r_nonsq is unused) """
        out_bits = []
        for i in range(self.prf_bit_len):
            mask, bit = self.kd_pre.get_ss_to_ss(where=i)
            opened = (mask * (self.key[i] + msg)).reveal()
            # 2*bit - 1 undoes the sign flip introduced by the mask choice
            out_bits.append(
                self.prf_helper.symbol_to_bit(opened.legendre() * (2*bit-1)))
        return out_bits

    def prf_legendre_ss_to_ss(self, msg, r_nonsq):
        """ Variant without Key-Dependent Pre-processing """
        out_bits = []
        for i in range(self.prf_bit_len):
            mask, bit = self._fresh_mask(r_nonsq)
            opened = (mask * (self.key[i] + msg)).reveal()
            out_bits.append(
                self.prf_helper.symbol_to_bit(opened.legendre() * (2*bit - 1)))
        return out_bits

    def encrypt_clear_to_ss(self, m):
        """ Takes as input m, returns [c] """
        if not PREPROC:
            return self.prf_helper.bit_compose(
                self.prf_legendre_clear_to_ss(m, self.precomp))
        out_bits = self.pre_prf_legendre_clear_to_ss(m, self.precomp)
        # one call's worth of pre-processing has been consumed
        self.kd_pre.gen_next_pre()
        return self.prf_helper.bit_compose(out_bits)

    def pre_prf_legendre_clear_to_ss(self, msg, r_nonsq):
        """ Variant with Key-Dependent Pre-processing (r_nonsq is unused) """
        out_bits = []
        for i in range(self.prf_bit_len):
            mask, bit, masked_key = self.kd_pre.get_clear_to_ss(where=i)
            # masked_key already holds mask * key[i]
            opened = (masked_key + mask * msg).reveal()
            out_bits.append(
                self.prf_helper.symbol_to_bit(opened.legendre() * (2*bit-1)))
        return out_bits

    def prf_legendre_clear_to_ss(self, msg, r_nonsq):
        """ Variant without Key-Dependent Pre-processing """
        out_bits = []
        for i in range(self.prf_bit_len):
            mask, bit = self._fresh_mask(r_nonsq)
            opened = (mask * self.key[i] + mask * msg).reveal()
            out_bits.append(
                self.prf_helper.symbol_to_bit(opened.legendre() * (2*bit - 1)))
        return out_bits

    def encrypt_ss_to_clear(self, m):
        """ Takes as input [m], returns c """
        return self.prf_helper.bit_compose(self.prf_legendre_ss_to_clear(m))

    def pre_prf_legendre_ss_to_clear(self, msg):
        """ Alias of prf_legendre_ss_to_clear """
        return self.prf_legendre_ss_to_clear(msg)

    def prf_legendre_ss_to_clear(self, msg):
        """ Variant with / without Key-Dependent Pre-processing """
        out_bits = []
        for i in range(self.prf_bit_len):
            _, sq = sint.get_random_square()
            opened = (sq * (self.key[i] + msg)).reveal()
            out_bits.append(self.prf_helper.symbol_to_bit(opened.legendre()))
        return out_bits

    def encrypt_clear_to_clear(self, m):
        """ Takes as input m, returns c """
        return self.prf_helper.bit_compose(self.prf_legendre_clear_to_clear(m))

    def prf_legendre_clear_to_clear(self, msg):
        """ Variant with / without Key-Dependent Pre-processing """
        out_bits = []
        for i in range(self.prf_bit_len):
            _, sq = sint.get_random_square()
            opened = (sq * self.key[i] + sq * msg).reveal()
            out_bits.append(self.prf_helper.symbol_to_bit(opened.legendre()))
        return out_bits