diff --git a/src/raft/mod.rs b/src/raft/mod.rs index 16a8e7ffa..a13710f3b 100644 --- a/src/raft/mod.rs +++ b/src/raft/mod.rs @@ -45,6 +45,9 @@ pub struct LogRequest { suffix: Logs, } +#[derive(SerialDecodable, SerialEncodable, Clone, Debug)] +pub struct BroadcastMsgRequest(Vec); + #[derive(SerialDecodable, SerialEncodable, Clone, Debug)] pub struct LogResponse { node_id: NodeId, @@ -119,6 +122,7 @@ pub enum NetMsgMethod { LogRequest = 1, VoteResponse = 2, VoteRequest = 3, + BroadcastRequest = 4, } impl Encodable for NetMsgMethod { @@ -128,6 +132,7 @@ impl Encodable for NetMsgMethod { Self::LogRequest => 1, Self::VoteResponse => 2, Self::VoteRequest => 3, + Self::BroadcastRequest => 4, }; (len as u8).encode(s) } @@ -140,7 +145,8 @@ impl Decodable for NetMsgMethod { 0 => Self::LogResponse, 1 => Self::LogRequest, 2 => Self::VoteResponse, - _ => Self::VoteRequest, + 3 => Self::VoteRequest, + _ => Self::BroadcastRequest, }) } } diff --git a/src/raft/p2p.rs b/src/raft/p2p.rs index fa66d9650..d7319d4df 100644 --- a/src/raft/p2p.rs +++ b/src/raft/p2p.rs @@ -9,7 +9,7 @@ use crate::{net, Result}; use super::{NetMsg, NodeId}; pub struct ProtocolRaft { - id: NodeId, + id: Option, jobsman: net::ProtocolJobsManagerPtr, notify_queue_sender: async_channel::Sender, msg_sub: net::MessageSubscription, @@ -19,7 +19,7 @@ pub struct ProtocolRaft { impl ProtocolRaft { pub async fn init( - id: NodeId, + id: Option, channel: net::ChannelPtr, notify_queue_sender: async_channel::Sender, p2p: net::P2pPtr, @@ -57,8 +57,8 @@ impl ProtocolRaft { let msg = (*msg).clone(); self.p2p.broadcast(msg.clone()).await?; - if let Some(rec_id) = msg.recipient_id.clone() { - if rec_id != self.id { + if msg.recipient_id.is_some() && self.id.is_some() { + if msg.recipient_id != self.id { continue } } diff --git a/src/raft/raft.rs b/src/raft/raft.rs index f2ed8ebcd..ccb984fee 100644 --- a/src/raft/raft.rs +++ b/src/raft/raft.rs @@ -16,8 +16,8 @@ use crate::{ }; use super::{ - DataStore, Log, LogRequest, LogResponse, Logs, NetMsg, NetMsgMethod, NodeId, ProtocolRaft, - Role, VoteRequest, VoteResponse, + BroadcastMsgRequest, DataStore, Log, LogRequest, LogResponse, Logs, NetMsg, NetMsgMethod, + NodeId, ProtocolRaft, Role, VoteRequest, VoteResponse, }; const HEARTBEATTIMEOUT: u64 = 100; @@ -29,7 +29,9 @@ type Sender = (async_channel::Sender, async_channel::Receiver); pub struct Raft { // this will be derived from the ip - id: NodeId, + // if the node doesn't have an id then will become a listener and doesn't have the right + // to request/response votes or response a confirmation for log + id: Option, // these five vars should be on local storage current_term: u64, @@ -60,7 +62,7 @@ pub struct Raft { } impl Raft { - pub fn new(addr: SocketAddr, db_path: PathBuf) -> Result { + pub fn new(addr: Option, db_path: PathBuf) -> Result { if db_path.to_str().is_none() { error!(target: "raft", "datastore path is incorrect"); return Err(Error::ParseFailed("unable to parse pathbuf to str")) @@ -90,7 +92,7 @@ impl Raft { let sender = async_channel::unbounded::(); Ok(Self { - id: NodeId::from(addr), + id: addr.map(NodeId::from), current_term, voted_for, logs, @@ -199,12 +201,20 @@ impl Raft { let log = Log { msg, term: self.current_term }; self.push_log(&log)?; - self.acked_length.insert(self.id.clone(), self.logs.len()); + self.acked_length.insert(self.id.clone().unwrap(), self.logs.len()); let nodes = self.nodes.lock().await.clone(); for node in nodes.iter() { self.update_logs(node.0).await?; } + } else { + let b_msg = BroadcastMsgRequest(serialize(msg)); + self.send( + self.current_leader.clone(), + &serialize(&b_msg), + NetMsgMethod::BroadcastRequest, + ) + .await?; } Ok(()) } @@ -227,6 +237,11 @@ impl Raft { let vr: VoteRequest = deserialize(&msg.payload)?; self.receive_vote_request(vr).await?; } + NetMsgMethod::BroadcastRequest => { + let vr: BroadcastMsgRequest = deserialize(&msg.payload)?; + let d: T = deserialize(&vr.0)?; + self.broadcast_msg(&d).await?; + } } Ok(()) } @@ -253,15 +268,22 @@ impl Raft { } async fn send_vote_request(&mut self) -> Result<()> { + // this will prevent the node to become a candidate + if self.id.is_none() { + return Ok(()) + } + + let self_id = self.id.clone().unwrap(); + self.set_current_term(&(self.current_term + 1))?; self.role = Role::Candidate; - self.set_voted_for(&Some(self.id.clone()))?; - self.votes_received.push(self.id.clone()); + self.set_voted_for(&Some(self_id.clone()))?; + self.votes_received.push(self_id.clone()); self.reset_last_term(); let request = VoteRequest { - node_id: self.id.clone(), + node_id: self_id, current_term: self.current_term, log_length: self.logs.len(), last_term: self.last_term, @@ -272,6 +294,10 @@ impl Raft { } async fn receive_vote_request(&mut self, vr: VoteRequest) -> Result<()> { + if self.id.is_none() { + return Ok(()) + } + if vr.current_term > self.current_term { self.set_current_term(&vr.current_term)?; self.set_voted_for(&None)?; @@ -291,8 +317,11 @@ impl Raft { true }; - let mut response = - VoteResponse { node_id: self.id.clone(), current_term: self.current_term, ok: false }; + let mut response = VoteResponse { + node_id: self.id.clone().unwrap(), + current_term: self.current_term, + ok: false, + }; if vr.current_term == self.current_term && vote_ok && vote { self.set_voted_for(&Some(vr.node_id.clone()))?; @@ -310,7 +339,7 @@ impl Raft { let nodes = self.nodes.lock().await; if self.votes_received.len() >= ((nodes.len() + 1) / 2) { self.role = Role::Leader; - self.current_leader = Some(self.id.clone()); + self.current_leader = Some(self.id.clone().unwrap()); for node in nodes.iter() { self.sent_length.insert(node.0.clone(), self.logs.len()); self.acked_length.insert(node.0.clone(), 0); @@ -337,7 +366,7 @@ impl Raft { } let request = LogRequest { - leader_id: self.id.clone(), + leader_id: self.id.clone().unwrap(), current_term: self.current_term, prefix_len, prefix_term, @@ -363,17 +392,22 @@ impl Raft { let ok = (self.logs.len() >= lr.prefix_len) && (lr.prefix_len == 0 || self.logs.get(lr.prefix_len - 1).term == lr.prefix_term); - let response: LogResponse = if lr.current_term == self.current_term && ok { + let mut ack = 0; + + if lr.current_term == self.current_term && ok { self.append_log(lr.prefix_len, lr.commit_length, &lr.suffix).await?; - let ack = lr.prefix_len + lr.suffix.len(); - LogResponse { node_id: self.id.clone(), current_term: self.current_term, ack, ok } - } else { - LogResponse { - node_id: self.id.clone(), - current_term: self.current_term, - ack: 0, - ok: false, - } + ack = lr.prefix_len + lr.suffix.len(); + } + + if self.id.is_none() { + return Ok(()) + } + + let response = LogResponse { + node_id: self.id.clone().unwrap(), + current_term: self.current_term, + ack, + ok, }; let payload = serialize(&response);