raft: prevent a listener node from request/response votes

This commit is contained in:
ghassmo
2022-04-09 09:12:56 +04:00
parent ba355bc95c
commit 277e908865
3 changed files with 68 additions and 28 deletions

View File

@@ -45,6 +45,9 @@ pub struct LogRequest {
suffix: Logs,
}
#[derive(SerialDecodable, SerialEncodable, Clone, Debug)]
pub struct BroadcastMsgRequest(Vec<u8>);
#[derive(SerialDecodable, SerialEncodable, Clone, Debug)]
pub struct LogResponse {
node_id: NodeId,
@@ -119,6 +122,7 @@ pub enum NetMsgMethod {
LogRequest = 1,
VoteResponse = 2,
VoteRequest = 3,
BroadcastRequest = 4,
}
impl Encodable for NetMsgMethod {
@@ -128,6 +132,7 @@ impl Encodable for NetMsgMethod {
Self::LogRequest => 1,
Self::VoteResponse => 2,
Self::VoteRequest => 3,
Self::BroadcastRequest => 4,
};
(len as u8).encode(s)
}
@@ -140,7 +145,8 @@ impl Decodable for NetMsgMethod {
0 => Self::LogResponse,
1 => Self::LogRequest,
2 => Self::VoteResponse,
_ => Self::VoteRequest,
3 => Self::VoteRequest,
_ => Self::BroadcastRequest,
})
}
}

View File

@@ -9,7 +9,7 @@ use crate::{net, Result};
use super::{NetMsg, NodeId};
pub struct ProtocolRaft {
id: NodeId,
id: Option<NodeId>,
jobsman: net::ProtocolJobsManagerPtr,
notify_queue_sender: async_channel::Sender<NetMsg>,
msg_sub: net::MessageSubscription<NetMsg>,
@@ -19,7 +19,7 @@ pub struct ProtocolRaft {
impl ProtocolRaft {
pub async fn init(
id: NodeId,
id: Option<NodeId>,
channel: net::ChannelPtr,
notify_queue_sender: async_channel::Sender<NetMsg>,
p2p: net::P2pPtr,
@@ -57,8 +57,8 @@ impl ProtocolRaft {
let msg = (*msg).clone();
self.p2p.broadcast(msg.clone()).await?;
if let Some(rec_id) = msg.recipient_id.clone() {
if rec_id != self.id {
if msg.recipient_id.is_some() && self.id.is_some() {
if msg.recipient_id != self.id {
continue
}
}

View File

@@ -16,8 +16,8 @@ use crate::{
};
use super::{
DataStore, Log, LogRequest, LogResponse, Logs, NetMsg, NetMsgMethod, NodeId, ProtocolRaft,
Role, VoteRequest, VoteResponse,
BroadcastMsgRequest, DataStore, Log, LogRequest, LogResponse, Logs, NetMsg, NetMsgMethod,
NodeId, ProtocolRaft, Role, VoteRequest, VoteResponse,
};
const HEARTBEATTIMEOUT: u64 = 100;
@@ -29,7 +29,9 @@ type Sender = (async_channel::Sender<NetMsg>, async_channel::Receiver<NetMsg>);
pub struct Raft<T> {
// this will be derived from the ip
id: NodeId,
// if the node doesn't have an id then will become a listener and doesn't have the right
// to request/response votes or response a confirmation for log
id: Option<NodeId>,
// these five vars should be on local storage
current_term: u64,
@@ -60,7 +62,7 @@ pub struct Raft<T> {
}
impl<T: Decodable + Encodable + Clone> Raft<T> {
pub fn new(addr: SocketAddr, db_path: PathBuf) -> Result<Self> {
pub fn new(addr: Option<SocketAddr>, db_path: PathBuf) -> Result<Self> {
if db_path.to_str().is_none() {
error!(target: "raft", "datastore path is incorrect");
return Err(Error::ParseFailed("unable to parse pathbuf to str"))
@@ -90,7 +92,7 @@ impl<T: Decodable + Encodable + Clone> Raft<T> {
let sender = async_channel::unbounded::<NetMsg>();
Ok(Self {
id: NodeId::from(addr),
id: addr.map(NodeId::from),
current_term,
voted_for,
logs,
@@ -199,12 +201,20 @@ impl<T: Decodable + Encodable + Clone> Raft<T> {
let log = Log { msg, term: self.current_term };
self.push_log(&log)?;
self.acked_length.insert(self.id.clone(), self.logs.len());
self.acked_length.insert(self.id.clone().unwrap(), self.logs.len());
let nodes = self.nodes.lock().await.clone();
for node in nodes.iter() {
self.update_logs(node.0).await?;
}
} else {
let b_msg = BroadcastMsgRequest(serialize(msg));
self.send(
self.current_leader.clone(),
&serialize(&b_msg),
NetMsgMethod::BroadcastRequest,
)
.await?;
}
Ok(())
}
@@ -227,6 +237,11 @@ impl<T: Decodable + Encodable + Clone> Raft<T> {
let vr: VoteRequest = deserialize(&msg.payload)?;
self.receive_vote_request(vr).await?;
}
NetMsgMethod::BroadcastRequest => {
let vr: BroadcastMsgRequest = deserialize(&msg.payload)?;
let d: T = deserialize(&vr.0)?;
self.broadcast_msg(&d).await?;
}
}
Ok(())
}
@@ -253,15 +268,22 @@ impl<T: Decodable + Encodable + Clone> Raft<T> {
}
async fn send_vote_request(&mut self) -> Result<()> {
// this will prevent the node to become a candidate
if self.id.is_none() {
return Ok(())
}
let self_id = self.id.clone().unwrap();
self.set_current_term(&(self.current_term + 1))?;
self.role = Role::Candidate;
self.set_voted_for(&Some(self.id.clone()))?;
self.votes_received.push(self.id.clone());
self.set_voted_for(&Some(self_id.clone()))?;
self.votes_received.push(self_id.clone());
self.reset_last_term();
let request = VoteRequest {
node_id: self.id.clone(),
node_id: self_id,
current_term: self.current_term,
log_length: self.logs.len(),
last_term: self.last_term,
@@ -272,6 +294,10 @@ impl<T: Decodable + Encodable + Clone> Raft<T> {
}
async fn receive_vote_request(&mut self, vr: VoteRequest) -> Result<()> {
if self.id.is_none() {
return Ok(())
}
if vr.current_term > self.current_term {
self.set_current_term(&vr.current_term)?;
self.set_voted_for(&None)?;
@@ -291,8 +317,11 @@ impl<T: Decodable + Encodable + Clone> Raft<T> {
true
};
let mut response =
VoteResponse { node_id: self.id.clone(), current_term: self.current_term, ok: false };
let mut response = VoteResponse {
node_id: self.id.clone().unwrap(),
current_term: self.current_term,
ok: false,
};
if vr.current_term == self.current_term && vote_ok && vote {
self.set_voted_for(&Some(vr.node_id.clone()))?;
@@ -310,7 +339,7 @@ impl<T: Decodable + Encodable + Clone> Raft<T> {
let nodes = self.nodes.lock().await;
if self.votes_received.len() >= ((nodes.len() + 1) / 2) {
self.role = Role::Leader;
self.current_leader = Some(self.id.clone());
self.current_leader = Some(self.id.clone().unwrap());
for node in nodes.iter() {
self.sent_length.insert(node.0.clone(), self.logs.len());
self.acked_length.insert(node.0.clone(), 0);
@@ -337,7 +366,7 @@ impl<T: Decodable + Encodable + Clone> Raft<T> {
}
let request = LogRequest {
leader_id: self.id.clone(),
leader_id: self.id.clone().unwrap(),
current_term: self.current_term,
prefix_len,
prefix_term,
@@ -363,17 +392,22 @@ impl<T: Decodable + Encodable + Clone> Raft<T> {
let ok = (self.logs.len() >= lr.prefix_len) &&
(lr.prefix_len == 0 || self.logs.get(lr.prefix_len - 1).term == lr.prefix_term);
let response: LogResponse = if lr.current_term == self.current_term && ok {
let mut ack = 0;
if lr.current_term == self.current_term && ok {
self.append_log(lr.prefix_len, lr.commit_length, &lr.suffix).await?;
let ack = lr.prefix_len + lr.suffix.len();
LogResponse { node_id: self.id.clone(), current_term: self.current_term, ack, ok }
} else {
LogResponse {
node_id: self.id.clone(),
current_term: self.current_term,
ack: 0,
ok: false,
}
ack = lr.prefix_len + lr.suffix.len();
}
if self.id.is_none() {
return Ok(())
}
let response = LogResponse {
node_id: self.id.clone().unwrap(),
current_term: self.current_term,
ack,
ok,
};
let payload = serialize(&response);