Remove generic from MockChannelFactory

This commit is contained in:
sinu
2023-05-12 09:27:52 -07:00
parent 73cc1e3ba4
commit 7cd41af8b6

View File

@@ -40,36 +40,64 @@ pub mod mock {
use super::*;
use std::{
collections::HashMap,
any::Any,
collections::{HashMap, HashSet},
sync::{Arc, Mutex},
};
use crate::duplex::DuplexChannel;
struct FactoryState<T> {
channel_buffer: HashMap<String, DuplexChannel<T>>,
#[derive(Default)]
struct FactoryState {
exists: HashSet<String>,
buffer: HashMap<String, Box<dyn Any + Send + 'static>>,
}
#[derive(Clone)]
pub struct MockMuxChannelFactory<T> {
state: Arc<Mutex<FactoryState<T>>>,
#[derive(Default, Clone)]
pub struct MockMuxChannelFactory {
state: Arc<Mutex<FactoryState>>,
}
impl<T> MockMuxChannelFactory<T>
where
T: Send + 'static,
{
impl MockMuxChannelFactory {
/// Creates a new mock mux channel factory
pub fn new() -> Self {
Self {
state: Arc::new(Mutex::new(FactoryState {
channel_buffer: HashMap::new(),
exists: HashSet::new(),
buffer: HashMap::new(),
})),
}
}
/// Sets up a channel with the provided id
pub fn setup_channel<T: Send + 'static>(
&self,
id: &str,
) -> Result<DuplexChannel<T>, MuxerError> {
let mut state = self.state.lock().unwrap();
if let Some(channel) = state.buffer.remove(id) {
if let Ok(channel) = channel.downcast::<DuplexChannel<T>>() {
Ok(*channel)
} else {
Err(MuxerError::InternalError(
"Failed to downcast channel".to_string(),
))
}
} else {
if !state.exists.insert(id.to_string()) {
return Err(MuxerError::DuplicateStreamId(id.to_string()));
}
let (channel_0, channel_1) = DuplexChannel::new();
state.buffer.insert(id.to_string(), Box::new(channel_1));
Ok(channel_0)
}
}
}
#[async_trait]
impl<T> MuxChannelControl<T> for MockMuxChannelFactory<T>
impl<T> MuxChannelControl<T> for MockMuxChannelFactory
where
T: Send + 'static,
{
@@ -77,15 +105,8 @@ pub mod mock {
&mut self,
id: String,
) -> Result<Box<dyn Channel<T, Error = std::io::Error>>, MuxerError> {
let mut state = self.state.lock().unwrap();
let channel = if let Some(channel) = state.channel_buffer.remove(&id) {
Box::new(channel)
} else {
let (channel_0, channel_1) = DuplexChannel::new();
state.channel_buffer.insert(id, channel_1);
Box::new(channel_0)
};
Ok(channel)
self.setup_channel(&id)
.map(|c| Box::new(c) as Box<dyn Channel<T, Error = std::io::Error>>)
}
}