Add verify tag and zk to aead (#390)

* add verify tag and zk to aead

* move local decryption into stream cipher

* fix arg name

* return plaintext

* truncate tag in zk methods
This commit is contained in:
sinu.eth
2023-12-29 14:06:16 -08:00
committed by GitHub
parent 3cb59ebf83
commit 9169088258
6 changed files with 268 additions and 30 deletions

View File

@@ -180,6 +180,20 @@ impl Aead for MpcAesGcm {
Ok(())
}
async fn decode_key_private(&mut self) -> Result<(), AeadError> {
self.aes_ctr
.decode_key_private()
.await
.map_err(AeadError::from)
}
async fn decode_key_blind(&mut self) -> Result<(), AeadError> {
self.aes_ctr
.decode_key_blind()
.await
.map_err(AeadError::from)
}
fn set_transcript_id(&mut self, id: &str) {
self.aes_ctr.set_transcript_id(id)
}
@@ -321,6 +335,52 @@ impl Aead for MpcAesGcm {
.map_err(AeadError::from)
.await
}
async fn verify_tag(
&mut self,
explicit_nonce: Vec<u8>,
mut ciphertext: Vec<u8>,
aad: Vec<u8>,
) -> Result<(), AeadError> {
let purported_tag = ciphertext.split_off(ciphertext.len() - AES_GCM_TAG_LEN);
let tag = self
.compute_tag(explicit_nonce.clone(), ciphertext, aad)
.await?;
// Reject if tag is incorrect
if tag == purported_tag {
Ok(())
} else {
Err(AeadError::CorruptedTag)
}
}
async fn prove_plaintext(
&mut self,
explicit_nonce: Vec<u8>,
mut ciphertext: Vec<u8>,
) -> Result<Vec<u8>, AeadError> {
ciphertext.truncate(ciphertext.len() - AES_GCM_TAG_LEN);
self.aes_ctr
.prove_plaintext(explicit_nonce, ciphertext)
.map_err(AeadError::from)
.await
}
async fn verify_plaintext(
&mut self,
explicit_nonce: Vec<u8>,
mut ciphertext: Vec<u8>,
) -> Result<(), AeadError> {
ciphertext.truncate(ciphertext.len() - AES_GCM_TAG_LEN);
self.aes_ctr
.verify_plaintext(explicit_nonce, ciphertext)
.map_err(AeadError::from)
.await
}
}
#[cfg(test)]
@@ -580,4 +640,37 @@ mod tests {
.unwrap_err();
assert!(matches!(err, AeadError::CorruptedTag));
}
#[tokio::test]
async fn test_aes_gcm_verify_tag() {
let key = vec![0u8; 16];
let iv = vec![0u8; 4];
let explicit_nonce = vec![0u8; 8];
let plaintext = vec![1u8; 32];
let aad = vec![2u8; 12];
let ciphertext = reference_impl(&key, &iv, &explicit_nonce, &plaintext, &aad);
let len = ciphertext.len();
let ((mut leader, mut follower), (_leader_vm, _follower_vm)) =
setup_pair(key.clone(), iv.clone()).await;
tokio::try_join!(
leader.verify_tag(explicit_nonce.clone(), ciphertext.clone(), aad.clone()),
follower.verify_tag(explicit_nonce.clone(), ciphertext.clone(), aad.clone())
)
.unwrap();
// corrupt tag
let mut corrupted = ciphertext.clone();
corrupted[len - 1] -= 1;
let (leader_res, follower_res) = tokio::join!(
leader.verify_tag(explicit_nonce.clone(), corrupted.clone(), aad.clone()),
follower.verify_tag(explicit_nonce.clone(), corrupted, aad.clone())
);
assert!(matches!(leader_res.unwrap_err(), AeadError::CorruptedTag));
assert!(matches!(follower_res.unwrap_err(), AeadError::CorruptedTag));
}
}

View File

@@ -49,6 +49,12 @@ pub trait Aead: Send {
/// Sets the key for the AEAD.
async fn set_key(&mut self, key: ValueRef, iv: ValueRef) -> Result<(), AeadError>;
/// Decodes the key for the AEAD, revealing it to this party.
async fn decode_key_private(&mut self) -> Result<(), AeadError>;
/// Decodes the key for the AEAD, revealing it to the other party(s).
async fn decode_key_blind(&mut self) -> Result<(), AeadError>;
/// Sets the transcript id
///
/// The AEAD assigns unique identifiers to each byte of plaintext
@@ -144,4 +150,48 @@ pub trait Aead: Send {
ciphertext: Vec<u8>,
aad: Vec<u8>,
) -> Result<(), AeadError>;
/// Verifies the tag of a ciphertext message.
///
/// This method checks the authenticity of the ciphertext, tag and additional data.
///
/// * `explicit_nonce` - The explicit nonce to use for decryption.
/// * `ciphertext` - The ciphertext and tag to authenticate and decrypt.
/// * `aad` - Additional authenticated data.
async fn verify_tag(
&mut self,
explicit_nonce: Vec<u8>,
ciphertext: Vec<u8>,
aad: Vec<u8>,
) -> Result<(), AeadError>;
/// Locally decrypts the provided ciphertext and then proves in ZK to the other party(s) that the
/// plaintext is correct.
///
/// Returns the plaintext.
///
/// This method requires this party to know the encryption key, which can be achieved by calling
/// the `decode_key_private` method.
///
/// # Arguments
///
/// * `explicit_nonce`: The explicit nonce to use for the keystream.
/// * `ciphertext`: The ciphertext to decrypt and prove.
async fn prove_plaintext(
&mut self,
explicit_nonce: Vec<u8>,
ciphertext: Vec<u8>,
) -> Result<Vec<u8>, AeadError>;
/// Verifies the other party(s) can prove they know a plaintext which encrypts to the given ciphertext.
///
/// # Arguments
///
/// * `explicit_nonce`: The explicit nonce to use for the keystream.
/// * `ciphertext`: The ciphertext to verify.
async fn verify_plaintext(
&mut self,
explicit_nonce: Vec<u8>,
ciphertext: Vec<u8>,
) -> Result<(), AeadError>;
}

View File

@@ -106,7 +106,7 @@ async fn bench_stream_cipher_zk(thread_count: usize, len: usize) {
let plaintext = vec![0u8; len];
let explicit_nonce = [0u8; 8];
let ciphertext = Aes128Ctr::apply_keystream(&key, &iv, 2, &explicit_nonce, &plaintext);
let ciphertext = Aes128Ctr::apply_keystream(&key, &iv, 2, &explicit_nonce, &plaintext).unwrap();
_ = tokio::try_join!(
leader.prove_plaintext(explicit_nonce.to_vec(), plaintext),

View File

@@ -5,12 +5,12 @@ use mpz_circuits::{
Circuit,
};
use crate::circuit::AES_CTR;
use crate::{circuit::AES_CTR, StreamCipherError};
/// A counter-mode block cipher circuit.
pub trait CtrCircuit: Default + Clone + Send + Sync + 'static {
/// The key type
type KEY: StaticValueType + Send + Sync + 'static;
type KEY: StaticValueType + TryFrom<Vec<u8>> + Send + Sync + 'static;
/// The block type
type BLOCK: StaticValueType
+ TryFrom<Vec<u8>>
@@ -54,12 +54,12 @@ pub trait CtrCircuit: Default + Clone + Send + Sync + 'static {
/// Applies the keystream to the message
fn apply_keystream(
key: &Self::KEY,
iv: &Self::IV,
key: &[u8],
iv: &[u8],
start_ctr: usize,
explicit_nonce: &Self::NONCE,
explicit_nonce: &[u8],
msg: &[u8],
) -> Vec<u8>;
) -> Result<Vec<u8>, StreamCipherError>;
}
/// A circuit for AES-128 in counter mode.
@@ -82,16 +82,35 @@ impl CtrCircuit for Aes128Ctr {
}
fn apply_keystream(
key: &Self::KEY,
iv: &Self::IV,
key: &[u8],
iv: &[u8],
start_ctr: usize,
explicit_nonce: &Self::NONCE,
explicit_nonce: &[u8],
msg: &[u8],
) -> Vec<u8> {
) -> Result<Vec<u8>, StreamCipherError> {
use ::cipher::{KeyIvInit, StreamCipher, StreamCipherSeek};
use aes::Aes128;
use ctr::Ctr32BE;
let key: &[u8; 16] = key
.try_into()
.map_err(|_| StreamCipherError::InvalidKeyLength {
expected: 16,
actual: key.len(),
})?;
let iv: &[u8; 4] = iv
.try_into()
.map_err(|_| StreamCipherError::InvalidIvLength {
expected: 4,
actual: iv.len(),
})?;
let explicit_nonce: &[u8; 8] = explicit_nonce.try_into().map_err(|_| {
StreamCipherError::InvalidExplicitNonceLength {
expected: 8,
actual: explicit_nonce.len(),
}
})?;
let mut full_iv = [0u8; 16];
full_iv[0..4].copy_from_slice(iv);
full_iv[4..12].copy_from_slice(explicit_nonce);
@@ -103,6 +122,6 @@ impl CtrCircuit for Aes128Ctr {
.expect("start counter is less than keystream length");
cipher.apply_keystream(&mut buf);
buf
Ok(buf)
}
}

View File

@@ -42,6 +42,10 @@ pub enum StreamCipherError {
VerifyError(#[from] mpz_garble::VerifyError),
#[error("key and iv is not set")]
KeyIvNotSet,
#[error("invalid key length: expected {expected}, got {actual}")]
InvalidKeyLength { expected: usize, actual: usize },
#[error("invalid iv length: expected {expected}, got {actual}")]
InvalidIvLength { expected: usize, actual: usize },
#[error("invalid explicit nonce length: expected {expected}, got {actual}")]
InvalidExplicitNonceLength { expected: usize, actual: usize },
#[error("missing value for {0}")]
@@ -57,6 +61,12 @@ where
/// Sets the key and iv for the stream cipher.
fn set_key(&mut self, key: ValueRef, iv: ValueRef);
/// Decodes the key for the stream cipher, revealing it to this party.
async fn decode_key_private(&mut self) -> Result<(), StreamCipherError>;
/// Decodes the key for the stream cipher, revealing it to the other party(s).
async fn decode_key_blind(&mut self) -> Result<(), StreamCipherError>;
/// Sets the transcript id
///
/// The stream cipher assigns unique identifiers to each byte of plaintext
@@ -149,17 +159,23 @@ where
ciphertext: Vec<u8>,
) -> Result<(), StreamCipherError>;
/// Privately proves to the other party(s) the plaintext encrypts to a certain ciphertext.
/// Locally decrypts the provided ciphertext and then proves in ZK to the other party(s) that the
/// plaintext is correct.
///
/// Returns the plaintext.
///
/// This method requires this party to know the encryption key, which can be achieved by calling
/// the `decode_key_private` method.
///
/// # Arguments
///
/// * `explicit_nonce`: The explicit nonce to use for the keystream.
/// * `plaintext`: The plaintext to prove.
/// * `ciphertext`: The ciphertext to decrypt and prove.
async fn prove_plaintext(
&mut self,
explicit_nonce: Vec<u8>,
plaintext: Vec<u8>,
) -> Result<(), StreamCipherError>;
ciphertext: Vec<u8>,
) -> Result<Vec<u8>, StreamCipherError>;
/// Verifies the other party(s) can prove they know a plaintext which encrypts to the given ciphertext.
///
@@ -306,7 +322,7 @@ mod tests {
(follower_encrypted_msg, follower_decrypted_msg),
) = futures::join!(leader_fut, follower_fut);
let reference = Aes128Ctr::apply_keystream(&key, &iv, 1, &explicit_nonce, &msg);
let reference = Aes128Ctr::apply_keystream(&key, &iv, 1, &explicit_nonce, &msg).unwrap();
assert_eq!(leader_encrypted_msg, reference);
assert_eq!(leader_decrypted_msg, msg);
@@ -324,7 +340,7 @@ mod tests {
let msg = b"This is a test message which will be encrypted using AES-CTR.".to_vec();
let ciphertext = Aes128Ctr::apply_keystream(&key, &iv, 1, &explicit_nonce, &msg);
let ciphertext = Aes128Ctr::apply_keystream(&key, &iv, 1, &explicit_nonce, &msg).unwrap();
let ((mut leader, mut follower), (mut leader_vm, mut follower_vm)) =
create_test_pair::<Aes128Ctr>(1, key, iv, 8).await;
@@ -398,7 +414,8 @@ mod tests {
.map(|(a, b)| a ^ b)
.collect::<Vec<u8>>();
let reference = Aes128Ctr::apply_keystream(&key, &iv, 1, &explicit_nonce, &[0u8; 16]);
let reference =
Aes128Ctr::apply_keystream(&key, &iv, 1, &explicit_nonce, &[0u8; 16]).unwrap();
assert_eq!(reference, key_block);
}
@@ -413,13 +430,15 @@ mod tests {
let msg = b"This is a test message which will be encrypted using AES-CTR.".to_vec();
let ciphertext = Aes128Ctr::apply_keystream(&key, &iv, 2, &explicit_nonce, &msg);
let ciphertext = Aes128Ctr::apply_keystream(&key, &iv, 2, &explicit_nonce, &msg).unwrap();
let ((mut leader, mut follower), (mut leader_vm, mut follower_vm)) =
create_test_pair::<Aes128Ctr>(2, key, iv, 8).await;
futures::try_join!(leader.decode_key_private(), follower.decode_key_blind()).unwrap();
futures::try_join!(
leader.prove_plaintext(explicit_nonce.to_vec(), msg),
leader.prove_plaintext(explicit_nonce.to_vec(), ciphertext.clone()),
follower.verify_plaintext(explicit_nonce.to_vec(), ciphertext)
)
.unwrap();

View File

@@ -28,6 +28,8 @@ where
}
struct State {
/// Encoded key and IV for the cipher.
encoded_key_iv: Option<EncodedKeyAndIv>,
/// Key and IV for the cipher.
key_iv: Option<KeyAndIv>,
/// Unique identifier for each execution of the cipher.
@@ -41,11 +43,17 @@ struct State {
}
#[derive(Clone)]
struct KeyAndIv {
struct EncodedKeyAndIv {
key: ValueRef,
iv: ValueRef,
}
#[derive(Clone)]
struct KeyAndIv {
key: Vec<u8>,
iv: Vec<u8>,
}
impl<C, E> MpcStreamCipher<C, E>
where
C: CtrCircuit,
@@ -60,6 +68,7 @@ where
Self {
config,
state: State {
encoded_key_iv: None,
key_iv: None,
execution_id,
transcript_counter,
@@ -102,9 +111,9 @@ where
len: usize,
mode: ExecutionMode,
) -> Result<ValueRef, StreamCipherError> {
let KeyAndIv { key, iv } = self
let EncodedKeyAndIv { key, iv } = self
.state
.key_iv
.encoded_key_iv
.clone()
.ok_or(StreamCipherError::KeyIvNotSet)?;
@@ -218,7 +227,41 @@ where
E: Thread + Execute + Prove + Verify + Decode + DecodePrivate + Send + Sync + 'static,
{
fn set_key(&mut self, key: ValueRef, iv: ValueRef) {
self.state.encoded_key_iv = Some(EncodedKeyAndIv { key, iv });
}
async fn decode_key_private(&mut self) -> Result<(), StreamCipherError> {
let EncodedKeyAndIv { key, iv } = self
.state
.encoded_key_iv
.clone()
.ok_or(StreamCipherError::KeyIvNotSet)?;
let mut scope = self.thread_pool.new_scope();
scope.push(move |thread| Box::pin(async move { thread.decode_private(&[key, iv]).await }));
let output = scope.wait().await.into_iter().next().unwrap()?;
let [key, iv]: [_; 2] = output.try_into().expect("decoded 2 values");
let key: Vec<u8> = key.try_into().expect("key is an array");
let iv: Vec<u8> = iv.try_into().expect("iv is an array");
self.state.key_iv = Some(KeyAndIv { key, iv });
Ok(())
}
async fn decode_key_blind(&mut self) -> Result<(), StreamCipherError> {
let EncodedKeyAndIv { key, iv } = self
.state
.encoded_key_iv
.clone()
.ok_or(StreamCipherError::KeyIvNotSet)?;
let mut scope = self.thread_pool.new_scope();
scope.push(move |thread| Box::pin(async move { thread.decode_blind(&[key, iv]).await }));
scope.wait().await.into_iter().next().unwrap()?;
Ok(())
}
fn set_transcript_id(&mut self, id: &str) {
@@ -480,8 +523,22 @@ where
async fn prove_plaintext(
&mut self,
explicit_nonce: Vec<u8>,
plaintext: Vec<u8>,
) -> Result<(), StreamCipherError> {
ciphertext: Vec<u8>,
) -> Result<Vec<u8>, StreamCipherError> {
let KeyAndIv { key, iv } = self
.state
.key_iv
.clone()
.ok_or(StreamCipherError::KeyIvNotSet)?;
let plaintext = C::apply_keystream(
&key,
&iv,
self.config.start_ctr,
&explicit_nonce,
&ciphertext,
)?;
// Prove plaintext encrypts back to ciphertext
let keystream = self
.compute_keystream(
@@ -497,7 +554,7 @@ where
.apply_keystream(
InputText::Private {
ids: plaintext_ids,
text: plaintext,
text: plaintext.clone(),
},
keystream,
ExecutionMode::Prove,
@@ -506,7 +563,7 @@ where
self.prove(ciphertext).await?;
Ok(())
Ok(plaintext)
}
async fn verify_plaintext(
@@ -546,9 +603,9 @@ where
explicit_nonce: Vec<u8>,
ctr: usize,
) -> Result<Vec<u8>, StreamCipherError> {
let KeyAndIv { key, iv } = self
let EncodedKeyAndIv { key, iv } = self
.state
.key_iv
.encoded_key_iv
.clone()
.ok_or(StreamCipherError::KeyIvNotSet)?;