diff --git a/mpc/garble/mpc-garble/src/protocol/deap/error.rs b/mpc/garble/mpc-garble/src/protocol/deap/error.rs index e74b73f51..2d7f8617d 100644 --- a/mpc/garble/mpc-garble/src/protocol/deap/error.rs +++ b/mpc/garble/mpc-garble/src/protocol/deap/error.rs @@ -41,6 +41,18 @@ pub enum FinalizationError { InvalidProof, } +/// Errors that can occur when accessing peer's encodings. +#[derive(Debug, thiserror::Error)] +#[allow(missing_docs)] +pub enum PeerEncodingsError { + #[error("Encodings not available since DEAP instance already finalized")] + AlreadyFinalized, + #[error("Value id was not found in registry: {0:?}")] + ValueIdNotFound(String), + #[error("Encoding is not available for value: {0:?}")] + EncodingNotAvailable(ValueRef), +} + impl From for ExecutionError { fn from(err: DEAPError) -> Self { match err { diff --git a/mpc/garble/mpc-garble/src/protocol/deap/mod.rs b/mpc/garble/mpc-garble/src/protocol/deap/mod.rs index 99f299cc6..4a198a20e 100644 --- a/mpc/garble/mpc-garble/src/protocol/deap/mod.rs +++ b/mpc/garble/mpc-garble/src/protocol/deap/mod.rs @@ -33,7 +33,7 @@ use crate::{ registry::ValueRegistry, }; -pub use error::DEAPError; +pub use error::{DEAPError, PeerEncodingsError}; pub use vm::{DEAPThread, DEAPVm}; use self::error::FinalizationError; @@ -870,6 +870,11 @@ impl DEAP { Ok(()) } + + // Returns a reference to the evaluator + pub(crate) fn ev(&self) -> &Evaluator { + &self.ev + } } impl State { diff --git a/mpc/garble/mpc-garble/src/protocol/deap/vm.rs b/mpc/garble/mpc-garble/src/protocol/deap/vm.rs index af55683d6..f9c7c8817 100644 --- a/mpc/garble/mpc-garble/src/protocol/deap/vm.rs +++ b/mpc/garble/mpc-garble/src/protocol/deap/vm.rs @@ -14,7 +14,7 @@ use mpc_circuits::{ Circuit, }; use mpc_core::value::ValueRef; -use mpc_garble_core::msg::GarbleMessage; +use mpc_garble_core::{encoding_state::Active, msg::GarbleMessage, EncodedValue}; use utils::id::NestedId; use utils_aio::{mux::MuxChannelControl, Channel}; @@ -25,7 +25,10 @@ use crate::{ ProveError, Thread, Verify, VerifyError, Vm, VmError, }; -use super::{error::FinalizationError, DEAPError, DEAP}; +use super::{ + error::{FinalizationError, PeerEncodingsError}, + DEAPError, DEAP, +}; type ChannelFactory = Box + Send + 'static>; type GarbleChannel = Box>; @@ -430,6 +433,48 @@ where } } +/// This trait provides methods to get peer's encodings. +trait PeerEncodings { + /// Returns the peer's encodings of the provided **input** values. + /// + /// # Errors + /// + /// Returns an error if the input value is not found or its encoding is not available. + fn get_peer_encodings( + &self, + value_ids: &[&str], + ) -> Result>, PeerEncodingsError>; +} + +impl PeerEncodings for DEAPVm { + fn get_peer_encodings( + &self, + value_ids: &[&str], + ) -> Result>, PeerEncodingsError> { + if self.finalized { + return Err(PeerEncodingsError::AlreadyFinalized)?; + } + + let deap = self.deap.as_ref().expect("instance set until finalization"); + + Ok(value_ids + .iter() + .map(|id| { + // get reference by id + let value_ref = match deap.get_value(id) { + Some(v) => v, + None => return Err(PeerEncodingsError::ValueIdNotFound(id.to_string())), + }; + // get encoding by reference + match deap.ev().get_encoding(&value_ref) { + Some(e) => Ok(e), + None => return Err(PeerEncodingsError::EncodingNotAvailable(value_ref)), + } + }) + .collect::, PeerEncodingsError>>()?) + } +} + #[cfg(test)] mod tests { use super::*; @@ -500,9 +545,22 @@ mod tests { assert_eq!(leader_result, follower_result); + // These encodings should be available + assert!(leader_vm.get_peer_encodings(&["msg", "ciphertext"]).is_ok()); + + // A non-existent value id will cause an error + let err = leader_vm + .get_peer_encodings(&["msg", "random_id"]) + .unwrap_err(); + assert!(matches!(err, PeerEncodingsError::ValueIdNotFound(_))); + let (leader_result, follower_result) = futures::join!(leader_vm.finalize(), follower_vm.finalize()); + // Trying to get encodings after finalization will cause an error + let err = leader_vm.get_peer_encodings(&["msg"]).unwrap_err(); + assert!(matches!(err, PeerEncodingsError::AlreadyFinalized)); + leader_result.unwrap(); follower_result.unwrap(); }