diff --git a/src/raft/raft.rs b/src/raft/consensus.rs similarity index 98% rename from src/raft/raft.rs rename to src/raft/consensus.rs index 488700348..c057ebd2b 100644 --- a/src/raft/raft.rs +++ b/src/raft/consensus.rs @@ -16,17 +16,17 @@ use crate::{ }; use super::{ - BroadcastMsgRequest, DataStore, Log, LogRequest, LogResponse, Logs, MapLength, NetMsg, - NetMsgMethod, NodeId, ProtocolRaft, Role, SyncRequest, SyncResponse, VoteRequest, VoteResponse, + primitives::{ + Broadcast, BroadcastMsgRequest, Log, LogRequest, LogResponse, Logs, MapLength, NetMsg, + NetMsgMethod, NodeId, Role, Sender, SyncRequest, SyncResponse, VoteRequest, VoteResponse, + }, + DataStore, ProtocolRaft, }; const HEARTBEATTIMEOUT: u64 = 100; const TIMEOUT: u64 = 300; const TIMEOUT_NODES: u64 = 300; -pub type Broadcast = (async_channel::Sender, async_channel::Receiver); -type Sender = (async_channel::Sender, async_channel::Receiver); - async fn load_node_ids_loop( nodes: Arc>>, p2p: net::P2pPtr, diff --git a/src/raft/datastore.rs b/src/raft/datastore.rs index e57a17506..443260d03 100644 --- a/src/raft/datastore.rs +++ b/src/raft/datastore.rs @@ -8,7 +8,7 @@ use crate::{ Result, }; -use super::{Log, NodeId}; +use super::primitives::{Log, NodeId}; const SLED_LOGS_TREE: &[u8] = b"_logs"; const SLED_COMMITS_TREE: &[u8] = b"_commits"; diff --git a/src/raft/mod.rs b/src/raft/mod.rs index c94bb8ac2..a74c9325a 100644 --- a/src/raft/mod.rs +++ b/src/raft/mod.rs @@ -1,239 +1,9 @@ -use std::{collections::HashMap, io, net::SocketAddr}; - -use crate::{ - util::serial::{serialize, Decodable, Encodable, SerialDecodable, SerialEncodable, VarInt}, - Error, Result, -}; - +mod consensus; mod datastore; +mod primitives; mod protocol_raft; -mod raft; +pub use consensus::Raft; use datastore::DataStore; -use protocol_raft::ProtocolRaft; -pub use raft::Raft; - -#[derive(PartialEq, Eq, Debug)] -pub enum Role { - Follower, - Candidate, - Leader, -} - -#[derive(SerialDecodable, SerialEncodable, Clone, Debug)] -pub struct SyncRequest { - logs_len: u64, - last_term: u64, -} - -#[derive(SerialDecodable, SerialEncodable, Clone, Debug)] -pub struct SyncResponse { - logs: Logs, - commit_length: u64, - leader_id: NodeId, - wipe: bool, -} - -#[derive(SerialDecodable, SerialEncodable, Clone, Debug)] -pub struct VoteRequest { - node_id: NodeId, - current_term: u64, - log_length: u64, - last_term: u64, -} - -#[derive(SerialDecodable, SerialEncodable, Clone, Debug)] -pub struct VoteResponse { - node_id: NodeId, - current_term: u64, - ok: bool, -} - -#[derive(SerialDecodable, SerialEncodable, Clone, Debug)] -pub struct LogRequest { - leader_id: NodeId, - current_term: u64, - prefix_len: u64, - prefix_term: u64, - commit_length: u64, - suffix: Logs, -} - -#[derive(SerialDecodable, SerialEncodable, Clone, Debug)] -pub struct BroadcastMsgRequest(Vec); - -#[derive(SerialDecodable, SerialEncodable, Clone, Debug)] -pub struct LogResponse { - node_id: NodeId, - current_term: u64, - ack: u64, - ok: bool, -} - -impl VoteResponse { - pub fn set_ok(&mut self, ok: bool) { - self.ok = ok; - } -} - -#[derive(Clone, Debug, SerialDecodable, SerialEncodable)] -pub struct Log { - term: u64, - msg: Vec, -} - -#[derive(Clone, Debug, Eq, PartialEq, Hash, SerialDecodable, SerialEncodable)] -pub struct NodeId(pub Vec); - -impl From for NodeId { - fn from(addr: SocketAddr) -> Self { - let ser = serialize(&addr); - let hash = blake3::hash(&ser).as_bytes().to_vec(); - Self(hash) - } -} - -#[derive(Clone, Debug)] -pub struct Logs(pub Vec); - -impl Logs { - pub fn len(&self) -> u64 { - self.0.len() as u64 - } - pub fn is_empty(&self) -> bool { - self.0.is_empty() - } - pub fn push(&mut self, d: &Log) { - self.0.push(d.clone()); - } - - pub fn slice_from(&self, start: u64) -> Option { - if self.len() >= start { - return Some(Self(self.0[start as usize..].to_vec())) - } - None - } - - pub fn slice_to(&self, end: u64) -> Self { - for i in (0..end).rev() { - if self.len() >= i { - return Self(self.0[..i as usize].to_vec()) - } - } - Self(vec![]) - } - - pub fn get(&self, index: u64) -> Result { - match self.0.get(index as usize) { - Some(l) => Ok(l.clone()), - None => Err(Error::RaftError("unable to indexing into vector".into())), - } - } - - pub fn to_vec(&self) -> Vec { - self.0.clone() - } -} - -#[derive(Clone, Debug)] -pub struct MapLength(pub HashMap); - -impl MapLength { - pub fn get(&self, key: &NodeId) -> Result { - match self.0.get(key) { - Some(v) => Ok(*v), - None => Err(Error::RaftError("unable to indexing into HashMap".into())), - } - } - - pub fn insert(&mut self, key: &NodeId, value: u64) { - self.0.insert(key.clone(), value); - } -} - -#[derive(SerialDecodable, SerialEncodable, Clone, Debug)] -pub struct NetMsg { - id: u64, - recipient_id: Option, - method: NetMsgMethod, - payload: Vec, -} - -#[derive(Clone, Debug, PartialEq, Eq)] -#[repr(u8)] -pub enum NetMsgMethod { - LogResponse = 0, - LogRequest = 1, - VoteResponse = 2, - VoteRequest = 3, - BroadcastRequest = 4, - // this only used for listener node - SyncRequest = 5, - SyncResponse = 6, -} - -impl Encodable for NetMsgMethod { - fn encode(&self, s: S) -> Result { - let len: usize = match self { - Self::LogResponse => 0, - Self::LogRequest => 1, - Self::VoteResponse => 2, - Self::VoteRequest => 3, - Self::BroadcastRequest => 4, - Self::SyncRequest => 5, - Self::SyncResponse => 6, - }; - (len as u8).encode(s) - } -} - -impl Decodable for NetMsgMethod { - fn decode(d: D) -> Result { - let com: u8 = Decodable::decode(d)?; - Ok(match com { - 0 => Self::LogResponse, - 1 => Self::LogRequest, - 2 => Self::VoteResponse, - 3 => Self::VoteRequest, - 4 => Self::BroadcastRequest, - 5 => Self::SyncRequest, - _ => Self::SyncResponse, - }) - } -} - -impl Encodable for Logs { - fn encode(&self, s: S) -> Result { - encode_vec(&self.0, s) - } -} - -impl Decodable for Logs { - fn decode(d: D) -> Result { - Ok(Self(decode_vec(d)?)) - } -} - -fn encode_vec(vec: &[T], mut s: S) -> Result { - let mut len = 0; - len += VarInt(vec.len() as u64).encode(&mut s)?; - for c in vec.iter() { - len += c.encode(&mut s)?; - } - Ok(len) -} - -fn decode_vec(mut d: D) -> Result> { - let len = VarInt::decode(&mut d)?.0; - let mut ret = Vec::with_capacity(len as usize); - for _ in 0..len { - ret.push(Decodable::decode(&mut d)?); - } - Ok(ret) -} - -#[cfg(test)] -mod tests { - #[test] - fn it_works() {} -} +pub use primitives::NetMsg; +pub use protocol_raft::ProtocolRaft; diff --git a/src/raft/primitives.rs b/src/raft/primitives.rs new file mode 100644 index 000000000..81659cf5a --- /dev/null +++ b/src/raft/primitives.rs @@ -0,0 +1,228 @@ +use std::{collections::HashMap, io, net::SocketAddr}; + +use crate::{ + util::serial::{serialize, Decodable, Encodable, SerialDecodable, SerialEncodable, VarInt}, + Error, Result, +}; + +pub type Broadcast = (async_channel::Sender, async_channel::Receiver); +pub type Sender = (async_channel::Sender, async_channel::Receiver); + +#[derive(PartialEq, Eq, Debug)] +pub enum Role { + Follower, + Candidate, + Leader, +} + +#[derive(SerialDecodable, SerialEncodable, Clone, Debug)] +pub struct SyncRequest { + pub logs_len: u64, + pub last_term: u64, +} + +#[derive(SerialDecodable, SerialEncodable, Clone, Debug)] +pub struct SyncResponse { + pub logs: Logs, + pub commit_length: u64, + pub leader_id: NodeId, + pub wipe: bool, +} + +#[derive(SerialDecodable, SerialEncodable, Clone, Debug)] +pub struct VoteRequest { + pub node_id: NodeId, + pub current_term: u64, + pub log_length: u64, + pub last_term: u64, +} + +#[derive(SerialDecodable, SerialEncodable, Clone, Debug)] +pub struct VoteResponse { + pub node_id: NodeId, + pub current_term: u64, + pub ok: bool, +} + +#[derive(SerialDecodable, SerialEncodable, Clone, Debug)] +pub struct LogRequest { + pub leader_id: NodeId, + pub current_term: u64, + pub prefix_len: u64, + pub prefix_term: u64, + pub commit_length: u64, + pub suffix: Logs, +} + +#[derive(SerialDecodable, SerialEncodable, Clone, Debug)] +pub struct BroadcastMsgRequest(pub Vec); + +#[derive(SerialDecodable, SerialEncodable, Clone, Debug)] +pub struct LogResponse { + pub node_id: NodeId, + pub current_term: u64, + pub ack: u64, + pub ok: bool, +} + +impl VoteResponse { + pub fn set_ok(&mut self, ok: bool) { + self.ok = ok; + } +} + +#[derive(Clone, Debug, SerialDecodable, SerialEncodable)] +pub struct Log { + pub term: u64, + pub msg: Vec, +} + +#[derive(Clone, Debug, Eq, PartialEq, Hash, SerialDecodable, SerialEncodable)] +pub struct NodeId(pub Vec); + +impl From for NodeId { + fn from(addr: SocketAddr) -> Self { + let ser = serialize(&addr); + let hash = blake3::hash(&ser).as_bytes().to_vec(); + Self(hash) + } +} + +#[derive(Clone, Debug)] +pub struct Logs(pub Vec); + +impl Logs { + pub fn len(&self) -> u64 { + self.0.len() as u64 + } + pub fn is_empty(&self) -> bool { + self.0.is_empty() + } + pub fn push(&mut self, d: &Log) { + self.0.push(d.clone()); + } + + pub fn slice_from(&self, start: u64) -> Option { + if self.len() >= start { + return Some(Self(self.0[start as usize..].to_vec())) + } + None + } + + pub fn slice_to(&self, end: u64) -> Self { + for i in (0..end).rev() { + if self.len() >= i { + return Self(self.0[..i as usize].to_vec()) + } + } + Self(vec![]) + } + + pub fn get(&self, index: u64) -> Result { + match self.0.get(index as usize) { + Some(l) => Ok(l.clone()), + None => Err(Error::RaftError("unable to indexing into vector".into())), + } + } + + pub fn to_vec(&self) -> Vec { + self.0.clone() + } +} + +#[derive(Clone, Debug)] +pub struct MapLength(pub HashMap); + +impl MapLength { + pub fn get(&self, key: &NodeId) -> Result { + match self.0.get(key) { + Some(v) => Ok(*v), + None => Err(Error::RaftError("unable to indexing into HashMap".into())), + } + } + + pub fn insert(&mut self, key: &NodeId, value: u64) { + self.0.insert(key.clone(), value); + } +} + +#[derive(SerialDecodable, SerialEncodable, Clone, Debug)] +pub struct NetMsg { + pub id: u64, + pub recipient_id: Option, + pub method: NetMsgMethod, + pub payload: Vec, +} + +#[derive(Clone, Debug, PartialEq, Eq)] +#[repr(u8)] +pub enum NetMsgMethod { + LogResponse = 0, + LogRequest = 1, + VoteResponse = 2, + VoteRequest = 3, + BroadcastRequest = 4, + // this only used for listener node + SyncRequest = 5, + SyncResponse = 6, +} + +impl Encodable for NetMsgMethod { + fn encode(&self, s: S) -> Result { + let len: usize = match self { + Self::LogResponse => 0, + Self::LogRequest => 1, + Self::VoteResponse => 2, + Self::VoteRequest => 3, + Self::BroadcastRequest => 4, + Self::SyncRequest => 5, + Self::SyncResponse => 6, + }; + (len as u8).encode(s) + } +} + +impl Decodable for NetMsgMethod { + fn decode(d: D) -> Result { + let com: u8 = Decodable::decode(d)?; + Ok(match com { + 0 => Self::LogResponse, + 1 => Self::LogRequest, + 2 => Self::VoteResponse, + 3 => Self::VoteRequest, + 4 => Self::BroadcastRequest, + 5 => Self::SyncRequest, + _ => Self::SyncResponse, + }) + } +} + +impl Encodable for Logs { + fn encode(&self, s: S) -> Result { + encode_vec(&self.0, s) + } +} + +impl Decodable for Logs { + fn decode(d: D) -> Result { + Ok(Self(decode_vec(d)?)) + } +} + +fn encode_vec(vec: &[T], mut s: S) -> Result { + let mut len = 0; + len += VarInt(vec.len() as u64).encode(&mut s)?; + for c in vec.iter() { + len += c.encode(&mut s)?; + } + Ok(len) +} + +fn decode_vec(mut d: D) -> Result> { + let len = VarInt::decode(&mut d)?.0; + let mut ret = Vec::with_capacity(len as usize); + for _ in 0..len { + ret.push(Decodable::decode(&mut d)?); + } + Ok(ret) +} diff --git a/src/raft/protocol_raft.rs b/src/raft/protocol_raft.rs index f1ad96b19..67a882de6 100644 --- a/src/raft/protocol_raft.rs +++ b/src/raft/protocol_raft.rs @@ -6,7 +6,7 @@ use log::debug; use crate::{net, Result}; -use super::{NetMsg, NetMsgMethod, NodeId}; +use super::primitives::{NetMsg, NetMsgMethod, NodeId}; pub struct ProtocolRaft { id: Option,