feat: add Pool type for pooling plugin instances (#696)

This commit is contained in:
zach
2025-06-06 10:14:22 -07:00
committed by GitHub
parent 30b4a7d2d3
commit 2732ca198d
6 changed files with 260 additions and 0 deletions

View File

@@ -29,6 +29,7 @@ pub(crate) mod manifest;
pub(crate) mod pdk;
mod plugin;
mod plugin_builder;
mod pool;
mod readonly_dir;
mod timer;
@@ -43,6 +44,7 @@ pub use plugin::{
CancelHandle, CompiledPlugin, Plugin, WasmInput, EXTISM_ENV_MODULE, EXTISM_USER_MODULE,
};
pub use plugin_builder::{DebugOptions, PluginBuilder};
pub use pool::{Pool, PoolBuilder, PoolPlugin};
pub(crate) use internal::{Internal, Wasi};
pub(crate) use timer::{Timer, TimerAction};

View File

@@ -191,6 +191,7 @@ pub(crate) fn profiling_strategy() -> ProfilingStrategy {
/// Defines an input type for Wasm data.
///
/// Types that implement `Into<WasmInput>` can be passed directly into `Plugin::new`
#[derive(Clone)]
pub enum WasmInput<'a> {
/// Raw Wasm module
Data(std::borrow::Cow<'a, [u8]>),

View File

@@ -33,6 +33,7 @@ impl Default for DebugOptions {
}
/// PluginBuilder is used to configure and create `Plugin` instances
#[derive(Clone)]
pub struct PluginBuilder<'a> {
pub(crate) source: WasmInput<'a>,
pub(crate) config: Option<wasmtime::Config>,

207
runtime/src/pool.rs Normal file
View File

@@ -0,0 +1,207 @@
use std::collections::HashMap;
use crate::{Error, FromBytesOwned, Plugin, PluginBuilder, ToBytes};
// `PoolBuilder` is used to configure and create `Pool`s
#[derive(Debug, Clone)]
pub struct PoolBuilder {
/// Max number of concurrent instances for a plugin - by default this is set to
/// the output of `std::thread::available_parallelism`
pub max_instances: usize,
}
impl PoolBuilder {
/// Create a `PoolBuilder` with default values
pub fn new() -> Self {
Self::default()
}
/// Set the max number of parallel instances
pub fn with_max_instances(mut self, n: usize) -> Self {
self.max_instances = n;
self
}
/// Create a new `Pool` with the given configuration
pub fn build(self) -> Pool {
Pool::new_from_builder(self)
}
}
impl Default for PoolBuilder {
fn default() -> Self {
PoolBuilder {
max_instances: std::thread::available_parallelism()
.expect("available parallelism")
.into(),
}
}
}
/// `PoolPlugin` is used by the pool to track the number of live instances of a particular plugin
#[derive(Clone, Debug)]
pub struct PoolPlugin(std::rc::Rc<std::cell::RefCell<Plugin>>);
impl PoolPlugin {
fn new(plugin: Plugin) -> Self {
Self(std::rc::Rc::new(std::cell::RefCell::new(plugin)))
}
/// Access the underlying plugin
pub fn plugin(&self) -> std::cell::RefMut<Plugin> {
self.0.borrow_mut()
}
/// Helper to call a plugin function on the underlying plugin
pub fn call<'a, Input: ToBytes<'a>, Output: FromBytesOwned>(
&self,
name: impl AsRef<str>,
input: Input,
) -> Result<Output, Error> {
self.plugin().call(name.as_ref(), input)
}
/// Helper to get the underlying plugin's ID
pub fn id(&self) -> uuid::Uuid {
self.plugin().id
}
}
type PluginSource = dyn Fn() -> Result<Plugin, Error>;
struct PoolInner<Key: std::fmt::Debug + Clone + std::hash::Hash + Eq = String> {
plugins: HashMap<Key, Box<PluginSource>>,
instances: HashMap<Key, Vec<PoolPlugin>>,
}
/// `Pool` manages threadsafe access to a limited number of instances of multiple plugins
#[derive(Clone)]
pub struct Pool<Key: std::fmt::Debug + Clone + std::hash::Hash + Eq = String> {
config: PoolBuilder,
inner: std::sync::Arc<std::sync::Mutex<PoolInner<Key>>>,
}
unsafe impl<T: std::fmt::Debug + Clone + std::hash::Hash + Eq> Send for Pool<T> {}
unsafe impl<T: std::fmt::Debug + Clone + std::hash::Hash + Eq> Sync for Pool<T> {}
impl<T: std::fmt::Debug + Clone + std::hash::Hash + Eq> Default for Pool<T> {
fn default() -> Self {
Self::new_from_builder(PoolBuilder::default())
}
}
impl<Key: std::fmt::Debug + Clone + std::hash::Hash + Eq> Pool<Key> {
/// Create a new pool with the defailt configuration
pub fn new() -> Self {
Self::default()
}
/// Create a new pool configured using a `PoolBuilder`
pub fn new_from_builder(builder: PoolBuilder) -> Self {
Pool {
config: builder,
inner: std::sync::Arc::new(std::sync::Mutex::new(PoolInner {
plugins: Default::default(),
instances: Default::default(),
})),
}
}
/// Add a plugin using a callback function
pub fn add<F: 'static + Fn() -> Result<Plugin, Error>>(&self, key: Key, source: F) {
let mut pool = self.inner.lock().unwrap();
if !pool.instances.contains_key(&key) {
pool.instances.insert(key.clone(), vec![]);
}
pool.plugins.insert(key, Box::new(source));
}
/// Add a plugin using a `PluginBuilder`
pub fn add_builder(&self, key: Key, source: PluginBuilder<'static>) {
let mut pool = self.inner.lock().unwrap();
if !pool.instances.contains_key(&key) {
pool.instances.insert(key.clone(), vec![]);
}
pool.plugins
.insert(key, Box::new(move || source.clone().build()));
}
fn find_available(&self, key: &Key) -> Result<Option<PoolPlugin>, Error> {
let mut pool = self.inner.lock().unwrap();
if let Some(entry) = pool.instances.get_mut(key) {
for instance in entry.iter() {
if std::rc::Rc::strong_count(&instance.0) == 1 {
return Ok(Some(instance.clone()));
}
}
}
Ok(None)
}
/// Get the number of live instances for a plugin
pub fn count(&self, key: &Key) -> usize {
self.inner
.lock()
.unwrap()
.instances
.get(key)
.map(|x| x.len())
.unwrap_or_default()
}
/// Get access to a plugin, this will create a new instance if needed (and allowed by the specified
/// max_instances). `Ok(None)` is returned if the timeout is reached before an available plugin could be
/// acquired
pub fn get(
&self,
key: &Key,
timeout: std::time::Duration,
) -> Result<Option<PoolPlugin>, Error> {
let start = std::time::Instant::now();
let max = self.config.max_instances;
if let Some(avail) = self.find_available(key)? {
return Ok(Some(avail));
}
{
let mut pool = self.inner.lock().unwrap();
if pool.instances.get(key).map(|x| x.len()).unwrap_or_default() < max {
if let Some(source) = pool.plugins.get(key) {
let plugin = source()?;
let instance = PoolPlugin::new(plugin);
let v = pool.instances.get_mut(key).unwrap();
v.push(instance);
return Ok(Some(v.last().unwrap().clone()));
}
}
}
loop {
if let Ok(Some(x)) = self.find_available(key) {
return Ok(Some(x));
}
if std::time::Instant::now() - start > timeout {
return Ok(None);
}
std::thread::sleep(std::time::Duration::from_millis(100));
}
}
/// Access a plugin in a callback function. This calls `Pool::get` then the provided
/// callback. `Ok(None)` is returned if the timeout is reached before an available
/// plugin could be acquired
pub fn with_plugin<T>(
&self,
key: &Key,
timeout: std::time::Duration,
f: impl FnOnce(&mut Plugin) -> Result<T, Error>,
) -> Result<Option<T>, Error> {
if let Some(plugin) = self.get(key, timeout)? {
return f(&mut plugin.plugin()).map(Some);
}
Ok(None)
}
}

View File

@@ -1,3 +1,4 @@
mod issues;
mod kernel;
mod pool;
mod runtime;

48
runtime/src/tests/pool.rs Normal file
View File

@@ -0,0 +1,48 @@
use crate::*;
fn run_thread(p: Pool<String>, i: u64) -> std::thread::JoinHandle<()> {
std::thread::spawn(move || {
std::thread::sleep(std::time::Duration::from_millis(i));
let s: String = p
.get(&"test".to_string(), std::time::Duration::from_secs(1))
.unwrap()
.unwrap()
.call("count_vowels", "abc")
.unwrap();
println!("{}", s);
})
}
#[test]
fn test_threads() {
for i in 1..=3 {
let data = include_bytes!("../../../wasm/code.wasm");
let pool: Pool<String> = PoolBuilder::new().with_max_instances(i).build();
let test = "test".to_string();
pool.add_builder(
test.clone(),
extism::PluginBuilder::new(extism::Manifest::new([extism::Wasm::data(data)]))
.with_wasi(true),
);
let mut threads = vec![];
threads.push(run_thread(pool.clone(), 1000));
threads.push(run_thread(pool.clone(), 1000));
threads.push(run_thread(pool.clone(), 1000));
threads.push(run_thread(pool.clone(), 1000));
threads.push(run_thread(pool.clone(), 1000));
threads.push(run_thread(pool.clone(), 1000));
threads.push(run_thread(pool.clone(), 500));
threads.push(run_thread(pool.clone(), 500));
threads.push(run_thread(pool.clone(), 500));
threads.push(run_thread(pool.clone(), 500));
threads.push(run_thread(pool.clone(), 500));
threads.push(run_thread(pool.clone(), 0));
for t in threads {
t.join().unwrap();
}
assert!(pool.count(&test) <= i);
}
}