mirror of
https://github.com/darkrenaissance/darkfi.git
synced 2026-04-28 03:00:18 -04:00
src/raft: using get() instead of indexing into hashmap and vector to avoid panic!
This commit is contained in:
@@ -1,8 +1,8 @@
|
||||
use std::{io, net::SocketAddr};
|
||||
use std::{collections::HashMap, io, net::SocketAddr};
|
||||
|
||||
use crate::{
|
||||
util::serial::{serialize, Decodable, Encodable, SerialDecodable, SerialEncodable, VarInt},
|
||||
Result,
|
||||
Error, Result,
|
||||
};
|
||||
|
||||
mod datastore;
|
||||
@@ -90,16 +90,27 @@ impl Logs {
|
||||
self.0.push(d.clone());
|
||||
}
|
||||
|
||||
pub fn slice_from(&self, start: u64) -> Self {
|
||||
Self(self.0[start as usize..].to_vec())
|
||||
pub fn slice_from(&self, start: u64) -> Option<Self> {
|
||||
if self.len() >= start {
|
||||
return Some(Self(self.0[start as usize..].to_vec()))
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
pub fn slice_to(&self, end: u64) -> Self {
|
||||
Self(self.0[..end as usize].to_vec())
|
||||
for i in (0..end).rev() {
|
||||
if self.len() >= i {
|
||||
return Self(self.0[..i as usize].to_vec())
|
||||
}
|
||||
}
|
||||
Self(vec![])
|
||||
}
|
||||
|
||||
pub fn get(&self, index: u64) -> Log {
|
||||
self.0[index as usize].clone()
|
||||
pub fn get(&self, index: u64) -> Result<Log> {
|
||||
match self.0.get(index as usize) {
|
||||
Some(l) => Ok(l.clone()),
|
||||
None => Err(Error::RaftError("unable to indexing into vector".into())),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_vec(&self) -> Vec<Log> {
|
||||
@@ -107,6 +118,22 @@ impl Logs {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct MapLength(pub HashMap<NodeId, u64>);
|
||||
|
||||
impl MapLength {
|
||||
pub fn get(&self, key: &NodeId) -> Result<u64> {
|
||||
match self.0.get(key) {
|
||||
Some(v) => Ok(v.clone()),
|
||||
None => Err(Error::RaftError("unable to indexing into HashMap".into())),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn insert(&mut self, key: &NodeId, value: u64) {
|
||||
self.0.insert(key.clone(), value);
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(SerialDecodable, SerialEncodable, Clone, Debug)]
|
||||
pub struct NetMsg {
|
||||
id: u32,
|
||||
|
||||
@@ -16,8 +16,8 @@ use crate::{
|
||||
};
|
||||
|
||||
use super::{
|
||||
BroadcastMsgRequest, DataStore, Log, LogRequest, LogResponse, Logs, NetMsg, NetMsgMethod,
|
||||
NodeId, ProtocolRaft, Role, VoteRequest, VoteResponse,
|
||||
BroadcastMsgRequest, DataStore, Log, LogRequest, LogResponse, Logs, MapLength, NetMsg,
|
||||
NetMsgMethod, NodeId, ProtocolRaft, Role, VoteRequest, VoteResponse,
|
||||
};
|
||||
|
||||
const HEARTBEATTIMEOUT: u64 = 100;
|
||||
@@ -47,8 +47,8 @@ pub struct Raft<T> {
|
||||
|
||||
votes_received: Vec<NodeId>,
|
||||
|
||||
sent_length: HashMap<NodeId, u64>,
|
||||
acked_length: HashMap<NodeId, u64>,
|
||||
sent_length: MapLength,
|
||||
acked_length: MapLength,
|
||||
|
||||
nodes: Arc<Mutex<HashMap<NodeId, SocketAddr>>>,
|
||||
|
||||
@@ -101,8 +101,8 @@ impl<T: Decodable + Encodable + Clone> Raft<T> {
|
||||
role: Role::Follower,
|
||||
current_leader: None,
|
||||
votes_received: vec![],
|
||||
sent_length: HashMap::new(),
|
||||
acked_length: HashMap::new(),
|
||||
sent_length: MapLength(HashMap::new()),
|
||||
acked_length: MapLength(HashMap::new()),
|
||||
nodes: Arc::new(Mutex::new(HashMap::new())),
|
||||
last_term: 0,
|
||||
sender,
|
||||
@@ -145,7 +145,7 @@ impl<T: Decodable + Encodable + Clone> Raft<T> {
|
||||
let msg: NetMsg = p2p_recv.recv().await.unwrap();
|
||||
match p2p_cloned.broadcast(msg).await {
|
||||
Ok(_) => {}
|
||||
Err(e) => error!(target: "raft", "error occurred during broadcasting a msg: {}", e)
|
||||
Err(e) => error!(target: "raft", "error occurred during broadcasting a msg: {}", e)
|
||||
}
|
||||
}
|
||||
}).detach();
|
||||
@@ -205,7 +205,7 @@ impl<T: Decodable + Encodable + Clone> Raft<T> {
|
||||
let log = Log { msg, term: self.current_term };
|
||||
self.push_log(&log)?;
|
||||
|
||||
self.acked_length.insert(self.id.clone().unwrap(), self.logs.len());
|
||||
self.acked_length.insert(&self.id.clone().unwrap(), self.logs.len());
|
||||
|
||||
let nodes = self.nodes.lock().await.clone();
|
||||
for node in nodes.iter() {
|
||||
@@ -361,8 +361,8 @@ impl<T: Decodable + Encodable + Clone> Raft<T> {
|
||||
self.role = Role::Leader;
|
||||
self.current_leader = Some(self.id.clone().unwrap());
|
||||
for node in nodes.iter() {
|
||||
self.sent_length.insert(node.0.clone(), self.logs.len());
|
||||
self.acked_length.insert(node.0.clone(), 0);
|
||||
self.sent_length.insert(&node.0, self.logs.len());
|
||||
self.acked_length.insert(&node.0, 0);
|
||||
self.update_logs(node.0).await?;
|
||||
}
|
||||
}
|
||||
@@ -377,12 +377,18 @@ impl<T: Decodable + Encodable + Clone> Raft<T> {
|
||||
}
|
||||
|
||||
async fn update_logs(&self, node_id: &NodeId) -> Result<()> {
|
||||
let prefix_len = self.sent_length[node_id];
|
||||
let suffix: Logs = self.logs.slice_from(prefix_len);
|
||||
let prefix_len = self.sent_length.get(node_id).unwrap().clone();
|
||||
|
||||
let suffix: Logs = if self.logs.slice_from(prefix_len).is_some() {
|
||||
self.logs.slice_from(prefix_len).unwrap()
|
||||
} else {
|
||||
return Ok(())
|
||||
};
|
||||
|
||||
let mut prefix_term = 0;
|
||||
|
||||
if prefix_len > 0 {
|
||||
prefix_term = self.logs.get(prefix_len - 1).term;
|
||||
prefix_term = self.logs.get(prefix_len - 1)?.term;
|
||||
}
|
||||
|
||||
let request = LogRequest {
|
||||
@@ -410,7 +416,7 @@ impl<T: Decodable + Encodable + Clone> Raft<T> {
|
||||
}
|
||||
|
||||
let ok = (self.logs.len() >= lr.prefix_len) &&
|
||||
(lr.prefix_len == 0 || self.logs.get(lr.prefix_len - 1).term == lr.prefix_term);
|
||||
(lr.prefix_len == 0 || self.logs.get(lr.prefix_len - 1)?.term == lr.prefix_term);
|
||||
|
||||
let mut ack = 0;
|
||||
|
||||
@@ -436,12 +442,12 @@ impl<T: Decodable + Encodable + Clone> Raft<T> {
|
||||
|
||||
async fn receive_log_response(&mut self, lr: LogResponse) -> Result<()> {
|
||||
if lr.current_term == self.current_term && self.role == Role::Leader {
|
||||
if lr.ok && lr.ack >= self.acked_length[&lr.node_id] {
|
||||
self.sent_length.insert(lr.node_id.clone(), lr.ack);
|
||||
self.acked_length.insert(lr.node_id, lr.ack);
|
||||
if lr.ok && lr.ack >= self.acked_length.get(&lr.node_id)? {
|
||||
self.sent_length.insert(&lr.node_id, lr.ack);
|
||||
self.acked_length.insert(&lr.node_id, lr.ack);
|
||||
self.commit_log().await?;
|
||||
} else if self.sent_length[&lr.node_id] > 0 {
|
||||
self.sent_length.insert(lr.node_id.clone(), self.sent_length[&lr.node_id] - 1);
|
||||
} else if self.sent_length.get(&lr.node_id)? > 0 {
|
||||
self.sent_length.insert(&lr.node_id, self.sent_length.get(&lr.node_id)? - 1);
|
||||
self.update_logs(&lr.node_id).await?;
|
||||
}
|
||||
} else if lr.current_term > self.current_term {
|
||||
@@ -462,7 +468,13 @@ impl<T: Decodable + Encodable + Clone> Raft<T> {
|
||||
}
|
||||
|
||||
fn acks(&self, nodes: HashMap<NodeId, SocketAddr>, length: u64) -> HashMap<NodeId, SocketAddr> {
|
||||
nodes.into_iter().filter(|n| self.acked_length[&n.0] >= length).collect()
|
||||
nodes
|
||||
.into_iter()
|
||||
.filter(|n| {
|
||||
let len = self.acked_length.get(&n.0);
|
||||
return len.is_ok() && len.unwrap() >= length
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
async fn commit_log(&mut self) -> Result<()> {
|
||||
@@ -485,10 +497,10 @@ impl<T: Decodable + Encodable + Clone> Raft<T> {
|
||||
}
|
||||
|
||||
let max_ready = *ready.iter().max().unwrap();
|
||||
if max_ready > self.commit_length && self.logs.get(max_ready - 1).term == self.current_term
|
||||
if max_ready > self.commit_length && self.logs.get(max_ready - 1)?.term == self.current_term
|
||||
{
|
||||
for i in self.commit_length..(max_ready - 1) {
|
||||
self.push_commit(&self.logs.get(i).msg).await?;
|
||||
self.push_commit(&self.logs.get(i)?.msg).await?;
|
||||
}
|
||||
|
||||
self.set_commit_length(&max_ready)?;
|
||||
@@ -505,20 +517,20 @@ impl<T: Decodable + Encodable + Clone> Raft<T> {
|
||||
) -> Result<()> {
|
||||
if suffix.len() > 0 && self.logs.len() > prefix_len {
|
||||
let index = min(self.logs.len(), prefix_len + suffix.len()) - 1;
|
||||
if self.logs.get(index).term != suffix.get(index - prefix_len).term {
|
||||
self.push_logs(&self.logs.slice_to(prefix_len - 1))?;
|
||||
if self.logs.get(index)?.term != suffix.get(index - prefix_len)?.term {
|
||||
self.push_logs(&self.logs.slice_to(prefix_len))?;
|
||||
}
|
||||
}
|
||||
|
||||
if prefix_len + suffix.len() > self.logs.len() {
|
||||
for i in (self.logs.len() - prefix_len)..(suffix.len() - 1) {
|
||||
self.push_log(&suffix.get(i))?;
|
||||
self.push_log(&suffix.get(i)?)?;
|
||||
}
|
||||
}
|
||||
|
||||
if leader_commit > self.commit_length {
|
||||
for i in self.commit_length..(leader_commit - 1) {
|
||||
self.push_commit(&self.logs.get(i).msg).await?;
|
||||
self.push_commit(&self.logs.get(i)?.msg).await?;
|
||||
}
|
||||
self.set_commit_length(&leader_commit)?;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user