From 75a132137214bfb62171b51ea2a42b16a09bfa84 Mon Sep 17 00:00:00 2001 From: ghassmo Date: Tue, 12 Apr 2022 17:01:27 +0400 Subject: [PATCH] src/raft: using get() instead of indexing into hashmap and vector to avoid panic! --- src/raft/mod.rs | 41 +++++++++++++++++++++++++------ src/raft/raft.rs | 64 ++++++++++++++++++++++++++++-------------------- 2 files changed, 72 insertions(+), 33 deletions(-) diff --git a/src/raft/mod.rs b/src/raft/mod.rs index e80788652..cfb24251c 100644 --- a/src/raft/mod.rs +++ b/src/raft/mod.rs @@ -1,8 +1,8 @@ -use std::{io, net::SocketAddr}; +use std::{collections::HashMap, io, net::SocketAddr}; use crate::{ util::serial::{serialize, Decodable, Encodable, SerialDecodable, SerialEncodable, VarInt}, - Result, + Error, Result, }; mod datastore; @@ -90,16 +90,27 @@ impl Logs { self.0.push(d.clone()); } - pub fn slice_from(&self, start: u64) -> Self { - Self(self.0[start as usize..].to_vec()) + pub fn slice_from(&self, start: u64) -> Option { + if self.len() >= start { + return Some(Self(self.0[start as usize..].to_vec())) + } + None } pub fn slice_to(&self, end: u64) -> Self { - Self(self.0[..end as usize].to_vec()) + for i in (0..end).rev() { + if self.len() >= i { + return Self(self.0[..i as usize].to_vec()) + } + } + Self(vec![]) } - pub fn get(&self, index: u64) -> Log { - self.0[index as usize].clone() + pub fn get(&self, index: u64) -> Result { + match self.0.get(index as usize) { + Some(l) => Ok(l.clone()), + None => Err(Error::RaftError("unable to indexing into vector".into())), + } } pub fn to_vec(&self) -> Vec { @@ -107,6 +118,22 @@ impl Logs { } } +#[derive(Clone, Debug)] +pub struct MapLength(pub HashMap); + +impl MapLength { + pub fn get(&self, key: &NodeId) -> Result { + match self.0.get(key) { + Some(v) => Ok(v.clone()), + None => Err(Error::RaftError("unable to indexing into HashMap".into())), + } + } + + pub fn insert(&mut self, key: &NodeId, value: u64) { + self.0.insert(key.clone(), value); + } +} + #[derive(SerialDecodable, SerialEncodable, Clone, Debug)] pub struct NetMsg { id: u32, diff --git a/src/raft/raft.rs b/src/raft/raft.rs index 3079da860..fff444838 100644 --- a/src/raft/raft.rs +++ b/src/raft/raft.rs @@ -16,8 +16,8 @@ use crate::{ }; use super::{ - BroadcastMsgRequest, DataStore, Log, LogRequest, LogResponse, Logs, NetMsg, NetMsgMethod, - NodeId, ProtocolRaft, Role, VoteRequest, VoteResponse, + BroadcastMsgRequest, DataStore, Log, LogRequest, LogResponse, Logs, MapLength, NetMsg, + NetMsgMethod, NodeId, ProtocolRaft, Role, VoteRequest, VoteResponse, }; const HEARTBEATTIMEOUT: u64 = 100; @@ -47,8 +47,8 @@ pub struct Raft { votes_received: Vec, - sent_length: HashMap, - acked_length: HashMap, + sent_length: MapLength, + acked_length: MapLength, nodes: Arc>>, @@ -101,8 +101,8 @@ impl Raft { role: Role::Follower, current_leader: None, votes_received: vec![], - sent_length: HashMap::new(), - acked_length: HashMap::new(), + sent_length: MapLength(HashMap::new()), + acked_length: MapLength(HashMap::new()), nodes: Arc::new(Mutex::new(HashMap::new())), last_term: 0, sender, @@ -145,7 +145,7 @@ impl Raft { let msg: NetMsg = p2p_recv.recv().await.unwrap(); match p2p_cloned.broadcast(msg).await { Ok(_) => {} - Err(e) => error!(target: "raft", "error occurred during broadcasting a msg: {}", e) + Err(e) => error!(target: "raft", "error occurred during broadcasting a msg: {}", e) } } }).detach(); @@ -205,7 +205,7 @@ impl Raft { let log = Log { msg, term: self.current_term }; self.push_log(&log)?; - self.acked_length.insert(self.id.clone().unwrap(), self.logs.len()); + self.acked_length.insert(&self.id.clone().unwrap(), self.logs.len()); let nodes = self.nodes.lock().await.clone(); for node in nodes.iter() { @@ -361,8 +361,8 @@ impl Raft { self.role = Role::Leader; self.current_leader = Some(self.id.clone().unwrap()); for node in nodes.iter() { - self.sent_length.insert(node.0.clone(), self.logs.len()); - self.acked_length.insert(node.0.clone(), 0); + self.sent_length.insert(&node.0, self.logs.len()); + self.acked_length.insert(&node.0, 0); self.update_logs(node.0).await?; } } @@ -377,12 +377,18 @@ impl Raft { } async fn update_logs(&self, node_id: &NodeId) -> Result<()> { - let prefix_len = self.sent_length[node_id]; - let suffix: Logs = self.logs.slice_from(prefix_len); + let prefix_len = self.sent_length.get(node_id).unwrap().clone(); + + let suffix: Logs = if self.logs.slice_from(prefix_len).is_some() { + self.logs.slice_from(prefix_len).unwrap() + } else { + return Ok(()) + }; let mut prefix_term = 0; + if prefix_len > 0 { - prefix_term = self.logs.get(prefix_len - 1).term; + prefix_term = self.logs.get(prefix_len - 1)?.term; } let request = LogRequest { @@ -410,7 +416,7 @@ impl Raft { } let ok = (self.logs.len() >= lr.prefix_len) && - (lr.prefix_len == 0 || self.logs.get(lr.prefix_len - 1).term == lr.prefix_term); + (lr.prefix_len == 0 || self.logs.get(lr.prefix_len - 1)?.term == lr.prefix_term); let mut ack = 0; @@ -436,12 +442,12 @@ impl Raft { async fn receive_log_response(&mut self, lr: LogResponse) -> Result<()> { if lr.current_term == self.current_term && self.role == Role::Leader { - if lr.ok && lr.ack >= self.acked_length[&lr.node_id] { - self.sent_length.insert(lr.node_id.clone(), lr.ack); - self.acked_length.insert(lr.node_id, lr.ack); + if lr.ok && lr.ack >= self.acked_length.get(&lr.node_id)? { + self.sent_length.insert(&lr.node_id, lr.ack); + self.acked_length.insert(&lr.node_id, lr.ack); self.commit_log().await?; - } else if self.sent_length[&lr.node_id] > 0 { - self.sent_length.insert(lr.node_id.clone(), self.sent_length[&lr.node_id] - 1); + } else if self.sent_length.get(&lr.node_id)? > 0 { + self.sent_length.insert(&lr.node_id, self.sent_length.get(&lr.node_id)? - 1); self.update_logs(&lr.node_id).await?; } } else if lr.current_term > self.current_term { @@ -462,7 +468,13 @@ impl Raft { } fn acks(&self, nodes: HashMap, length: u64) -> HashMap { - nodes.into_iter().filter(|n| self.acked_length[&n.0] >= length).collect() + nodes + .into_iter() + .filter(|n| { + let len = self.acked_length.get(&n.0); + return len.is_ok() && len.unwrap() >= length + }) + .collect() } async fn commit_log(&mut self) -> Result<()> { @@ -485,10 +497,10 @@ impl Raft { } let max_ready = *ready.iter().max().unwrap(); - if max_ready > self.commit_length && self.logs.get(max_ready - 1).term == self.current_term + if max_ready > self.commit_length && self.logs.get(max_ready - 1)?.term == self.current_term { for i in self.commit_length..(max_ready - 1) { - self.push_commit(&self.logs.get(i).msg).await?; + self.push_commit(&self.logs.get(i)?.msg).await?; } self.set_commit_length(&max_ready)?; @@ -505,20 +517,20 @@ impl Raft { ) -> Result<()> { if suffix.len() > 0 && self.logs.len() > prefix_len { let index = min(self.logs.len(), prefix_len + suffix.len()) - 1; - if self.logs.get(index).term != suffix.get(index - prefix_len).term { - self.push_logs(&self.logs.slice_to(prefix_len - 1))?; + if self.logs.get(index)?.term != suffix.get(index - prefix_len)?.term { + self.push_logs(&self.logs.slice_to(prefix_len))?; } } if prefix_len + suffix.len() > self.logs.len() { for i in (self.logs.len() - prefix_len)..(suffix.len() - 1) { - self.push_log(&suffix.get(i))?; + self.push_log(&suffix.get(i)?)?; } } if leader_commit > self.commit_length { for i in self.commit_length..(leader_commit - 1) { - self.push_commit(&self.logs.get(i).msg).await?; + self.push_commit(&self.logs.get(i)?.msg).await?; } self.set_commit_length(&leader_commit)?; }