Files
MP-SPDZ/Programs/Source/htmac.mpc

196 lines
6.5 KiB
Plaintext

""" Execfile can be avoided by implementing this ORAM-style
(i.e. by putting the .py sources in the Compiler/ directory).
We wanted to keep a clean directory structure, hence the PRF sources
are loaded with exec() (the Python 3 replacement for execfile).
This is HtMAC construction as specified in our ToSC paper:
https://eprint.iacr.org/2017/496
Line with (set_global_vector_size) says:
'execute the PRF in n_parallel parallel instances'
This unrolls to single instructions which execute n_parallel times
in SPDZ on-line phase. (check the assembly code for more details
by compiling with -a flag)
For single instructions executing one time set n_parallel=1
- test_decryption = False => Compute Enc() + Tag-Gen.
- test_decryption = True => Compute Enc() + Tag-Gen.
Then compute Dec() + Tag-Check.
Howto run?
- Compile it via ./compile.py htmac n_parallel n_total nmessages
Eg: ./compile.py htmac 2 8 10
The command above compiles 8 HtMAC instances, where 2 is the level of parallelism
and each instance authenticates 10-block messages.
A simplistic view of what is happening:
HtMAC(m) HtMAC(m) (1 run, 2 instances)
HtMAC(m) HtMAC(m) (1 run, 2 instances)
HtMAC(m) HtMAC(m) (1 run, 2 instances)
HtMAC(m) HtMAC(m) (1 run, 2 instances)
where |m| = 10 blocks
- To run the online phase: ./Scripts/run-online htmac-2-8-10
"""
from Compiler.program import Program
from Compiler import instructions_base
from random import randint
import sys
# Set the compiler's bit length for this program.
program.bit_length = 128

# Command line: ./compile.py htmac n_parallel n_total nmessages
# (sys.argv[1] is the program name, so the parameters start at index 2).
n_parallel = int(sys.argv[2])
n_total = int(sys.argv[3])
nmessages = int(sys.argv[4])

use_mimc_prf = True
# Use just one PRF: the Legendre PRF is selected only when MiMC is not.
use_leg_prf = not use_mimc_prf
test_decryption = True

# Every instruction emitted from here on is executed n_parallel times.
instructions_base.set_global_vector_size(n_parallel)

# Pick the source file of the selected PRF.
if use_mimc_prf:
    _prf_source_path = './Programs/Source/prf_mimc.mpc'
else:
    _prf_source_path = './Programs/Source/prf_leg.mpc'

# Read the PRF source with a context manager so the file handle is closed
# again (the previous one-liner leaked it).
# NOTE(review): open() is reached through __builtins__ as in the original
# code -- presumably because the compiler namespace shadows the name `open`;
# confirm before simplifying to a plain open().
with __builtins__['open'](_prf_source_path) as _prf_source_file:
    _prf_source = _prf_source_file.read()

# exec() must stay at module level so that the PRF definitions end up in this
# program's global namespace, where time_private_mac() looks them up.
exec(compile(_prf_source, _prf_source_path, 'exec'))
class HMAC(object):
    """Tag generation part of the HtMAC construction.

    The ciphertext blocks are summed, the sum is hashed, the tweak E(1) is
    subtracted and the result is passed through the cipher once more.
    update_tweak() has to be called with E(1) before any tag is computed.
    """

    def __init__(self, _enc):
        """Remember the cipher used to finalise tags.

        _enc: object offering encrypt_ss_to_clear() and encrypt_ss_to_ss().
        """
        self.cipher = _enc
        # Tweak E(1); filled in later through update_tweak().
        self.one_mask = None

    def update_tweak(self, one_mask):
        """Install the tweak E(1) that is subtracted from every digest."""
        self.one_mask = one_mask

    def default_auth(self, ciphertext):
        """Return the cipher input from which the tag is derived.

        ciphertext: indexable collection holding nmessages clear blocks.
        Returns digest(sum of all blocks) - E(1).
        """
        sigma = cint(0)
        # Cover every ciphertext block. The previous bound of nmessages - 1
        # left the last block out of the sum, so it was not authenticated.
        for i in range(nmessages):
            sigma += ciphertext[i]
        digest = sigma.digest(16)  # compute H(sigma)[0..127]
        # E^{-1,0}(digest) = E(digest + -1 * E(1))
        final_mask = digest + -1 * self.one_mask
        return final_mask

    def encryption_auth(self, ciphertext):
        """Compute the tag on the encryption side; the tag is clear."""
        final_mask = self.default_auth(ciphertext)
        return self.cipher.encrypt_ss_to_clear(final_mask)

    def decryption_auth(self, ciphertext):
        """Compute the tag on the decryption side; the tag is shared."""
        final_mask = self.default_auth(ciphertext)
        return self.cipher.encrypt_ss_to_ss(final_mask)
class NonceEncryptMAC(object):
    """Nonce-based encryption combined with an HtMAC tag.

    Block i of a message is masked with E(N + (i+1) * E(1)), where N is a
    public random nonce and E is the supplied cipher; the tag over the
    ciphertext is produced by the supplied MAC object.
    """

    def __init__(self, _mac, _enc):
        """Store MAC and cipher and precompute the tweak E(1).

        _mac: object offering encryption_auth() and decryption_auth().
        _enc: object offering encrypt_clear_to_ss() and encrypt_ss_to_ss().
        """
        self.mac = _mac
        self.enc = _enc
        self.one_mask = self.enc.encrypt_clear_to_ss(cint(1))  # E(1)

    def get_one_mask(self):
        """Return the precomputed tweak E(1)."""
        return self.one_mask

    def get_long_random(self, nbits):
        """ Returns random cint() % 2^{nbits}

        The value is assembled from 30-bit chunks of regint.get_random().
        Each chunk is shifted in *before* it is added, and the final partial
        chunk only shifts by the remaining bit count. The earlier version
        shifted by 30 after every chunk, including the last one, which
        produced a value wider than nbits.
        """
        chunk_bits = 30
        result = cint(0)
        for _ in range(nbits // chunk_bits):
            result <<= chunk_bits
            result += cint(regint.get_random(chunk_bits))
        remaining_bits = nbits % chunk_bits
        if remaining_bits:
            result <<= remaining_bits
            result += cint(regint.get_random(remaining_bits))
        return result

    def apply(self, message):
        """Encrypt and authenticate `message`.

        message: indexable collection of nmessages blocks.
        Returns (nonce, revealed ciphertext blocks, tag).
        """
        nonce = self.get_long_random(120)
        one_mask = self.one_mask
        blocks = [None] * nmessages
        for ctr in range(nmessages):
            # mask \asn E(N + i * E(1))
            mask = self.enc.encrypt_ss_to_ss(nonce + (ctr + 1) * one_mask)
            temp = message[ctr] + mask
            blocks[ctr] = temp.reveal()
        return nonce, blocks, self.mac.encryption_auth(blocks)

    def default_decrypt(self, nonce, ciphertext):
        """Strip the masks from `ciphertext` and recompute its tag.

        Returns (plaintext blocks, recomputed tag); nothing is checked here.
        """
        one_mask = self.one_mask
        blocks = [None] * nmessages
        for ctr in range(nmessages):
            mask = self.enc.encrypt_ss_to_ss(nonce + (ctr + 1) * one_mask)
            blocks[ctr] = ciphertext[ctr] - mask
        computed_tag = self.mac.decryption_auth(ciphertext)
        return blocks, computed_tag

    def decrypt(self, nonce, ciphertext, recv_tag):
        """Decrypt `ciphertext` and compare `recv_tag` with the recomputed tag.

        The tag difference is multiplied by a random secret value before it
        is revealed. A non-zero result prints "MAC ERROR. PANIC" at run time.
        The plaintext and the recomputed tag are returned in either case.
        """
        message, computed_tag = self.default_decrypt(nonce, ciphertext)
        # First element of a random triple serves as the blinding factor.
        num_random = sint.get_random_triple()[0]
        zero_checker = num_random * (recv_tag - computed_tag)
        zero_checker = zero_checker.reveal()

        @if_(zero_checker != 0)
        def f():
            print_ln("MAC ERROR. PANIC")

        return message, computed_tag

    def checker(self, nonce, ciphertext, recv_tag):
        """Run decrypt() and print the revealed plaintext blocks."""
        message, _ = self.decrypt(nonce, ciphertext, recv_tag)
        print_ln('decrypted plaintext:')
        for block in message:
            print_str('%s ', block.reveal())
        print_ln()
def time_private_mac(n_total, n_parallel, nmessages):
    """Benchmark HtMAC: n_total runs, n_parallel of them per loop iteration.

    n_total:    total number of HtMAC instances to execute.
    n_parallel: number of instances handled by one loop iteration.
    nmessages:  number of blocks in every message.

    The work is split into named basic blocks ('preproc-block',
    'preproc-block2', 'online-block') and the loop is wrapped in timer 1.
    """

    def open_basicblock(label):
        # Start a fresh, named basic block on the current tape.
        Program.prog.curr_tape.start_new_basicblock(name=label)

    # Separate blocks make it easier to count the required off-line data
    # by measuring the amount of triples/bits used between blocks.
    open_basicblock('preproc-block')

    # N+1 PRF calls for Enc + Tag-Gen; testing decryption needs 2(N+1)
    # calls: (N+1) to encrypt plus (N+1) to decrypt.
    passes = 2 if test_decryption else 1
    prf_calls = (nmessages + 1) * passes

    # Branch on the desired PRF.
    cipher = None
    if use_mimc_prf:
        prf_key = sint.get_random_int(128)
        cipher = MiMC(73, prf_key, prf_calls)
    elif use_leg_prf:
        cipher = LegPRF(prf_calls, bit_len=128)

    open_basicblock('preproc-block2')

    # Build the HtMAC scheme and hand the tweak E_k(1) to the MAC part.
    mac = HMAC(cipher)
    scheme = NonceEncryptMAC(mac, cipher)
    mac.update_tweak(scheme.get_one_mask())

    open_basicblock('online-block')

    # Benchmark n_total HtMACs while executing n_parallel of them in parallel.
    start_timer(1)

    @for_range(n_total // n_parallel)
    def block(index):
        # Re-use off-line data after n_parallel runs for benchmarking
        # purposes. A real system would have to initialise the number of
        # PRF calls with a larger constant instead.
        cipher.reset_kd_pre()
        plaintext = [sint(2 * i + 1) for i in range(nmessages)]
        nonce, ciphertext, tag = scheme.apply(plaintext)
        if test_decryption:
            scheme.checker(nonce, ciphertext, tag)

    stop_timer(1)
# Frame the benchmark output with a separator line before and after the run.
_separator = '##############################################'
print_ln(_separator)
time_private_mac(n_total, n_parallel, nmessages)
print_ln(_separator)