feat(tasks): add WorkerPool with per-thread Worker state (#22154)

Co-authored-by: Amp <amp@ampcode.com>
This commit is contained in:
Georgios Konstantopoulos
2026-02-16 15:46:51 -08:00
committed by GitHub
parent 02513ecf3b
commit 2eec519bf9
6 changed files with 312 additions and 55 deletions

View File

@@ -276,6 +276,7 @@ where
F: DatabaseProviderROFactory<Provider: TrieCursorFactory + HashedCursorFactory>
+ Clone
+ Send
+ Sync
+ 'static,
{
// start preparing transactions immediately

View File

@@ -299,7 +299,7 @@ where
);
let ctx = self.ctx.clone();
self.executor.prewarming_pool().install(|| {
self.executor.prewarming_pool().install_fn(|| {
bal.par_iter().for_each_init(
|| (ctx.clone(), None::<CachedStateProvider<reth_provider::StateProviderBox>>),
|(ctx, provider), account| {

View File

@@ -36,6 +36,8 @@ pub mod shutdown;
#[cfg(feature = "rayon")]
pub mod pool;
#[cfg(feature = "rayon")]
pub use pool::{Worker, WorkerPool};
/// Lock-free ordered parallel iterator extension trait.
#[cfg(feature = "rayon")]

View File

@@ -1,10 +1,15 @@
//! Additional helpers for executing tracing calls
use std::{
any::Any,
cell::RefCell,
future::Future,
panic::{catch_unwind, AssertUnwindSafe},
pin::Pin,
sync::Arc,
sync::{
atomic::{AtomicUsize, Ordering},
Arc,
},
task::{ready, Context, Poll},
thread,
};
@@ -151,6 +156,168 @@ impl<T> Future for BlockingTaskHandle<T> {
#[non_exhaustive]
pub struct TokioBlockingTaskError;
thread_local! {
static WORKER: RefCell<Worker> = const { RefCell::new(Worker::new()) };
}
/// A rayon thread pool with per-thread [`Worker`] state.
///
/// Each thread in the pool has its own [`Worker`] that can hold arbitrary state via
/// [`Worker::init`]. The state is thread-local and accessible during [`install`](Self::install)
/// calls.
///
/// The pool supports multiple init/clear cycles, allowing reuse of the same threads with
/// different state configurations.
#[derive(Debug)]
pub struct WorkerPool {
pool: rayon::ThreadPool,
}
impl WorkerPool {
/// Creates a new `WorkerPool` with the given number of threads.
pub fn new(num_threads: usize) -> Result<Self, rayon::ThreadPoolBuildError> {
Self::from_builder(rayon::ThreadPoolBuilder::new().num_threads(num_threads))
}
/// Creates a new `WorkerPool` from a [`rayon::ThreadPoolBuilder`].
pub fn from_builder(
builder: rayon::ThreadPoolBuilder,
) -> Result<Self, rayon::ThreadPoolBuildError> {
Ok(Self { pool: builder.build()? })
}
/// Returns the total number of threads in the underlying rayon pool.
pub fn current_num_threads(&self) -> usize {
self.pool.current_num_threads()
}
/// Runs a closure on `num_threads` threads in the pool, giving mutable access to each
/// thread's [`Worker`].
///
/// Use this to initialize or re-initialize per-thread state via [`Worker::init`].
/// Only `num_threads` threads execute the closure; the rest skip it.
pub fn broadcast(&self, num_threads: usize, f: impl Fn(&mut Worker) + Sync) {
if num_threads >= self.pool.current_num_threads() {
// Fast path: run on every thread, no atomic coordination needed.
self.pool.broadcast(|_| {
WORKER.with_borrow_mut(|worker| f(worker));
});
} else {
let remaining = AtomicUsize::new(num_threads);
self.pool.broadcast(|_| {
// Atomically claim a slot; threads that can't decrement skip the closure.
let mut current = remaining.load(Ordering::Relaxed);
loop {
if current == 0 {
return;
}
match remaining.compare_exchange_weak(
current,
current - 1,
Ordering::Relaxed,
Ordering::Relaxed,
) {
Ok(_) => break,
Err(actual) => current = actual,
}
}
WORKER.with_borrow_mut(|worker| f(worker));
});
}
}
/// Clears the state on every thread in the pool.
pub fn clear(&self) {
self.pool.broadcast(|_| {
WORKER.with_borrow_mut(Worker::clear);
});
}
/// Runs a closure on the pool with access to the calling thread's [`Worker`].
///
/// All rayon parallelism (e.g. `par_iter`) spawned inside the closure executes on this pool.
/// Each thread can access its own [`Worker`] via the provided reference or through additional
/// [`WorkerPool::with_worker`] calls.
pub fn install<R: Send>(&self, f: impl FnOnce(&Worker) -> R + Send) -> R {
self.pool.install(|| WORKER.with_borrow(|worker| f(worker)))
}
/// Runs a closure on the pool without worker state access.
///
/// Like [`install`](Self::install) but for closures that don't need per-thread [`Worker`]
/// state.
pub fn install_fn<R: Send>(&self, f: impl FnOnce() -> R + Send) -> R {
self.pool.install(f)
}
/// Spawns a closure on the pool.
pub fn spawn(&self, f: impl FnOnce() + Send + 'static) {
self.pool.spawn(f);
}
/// Access the current thread's [`Worker`] from within an [`install`](Self::install) closure.
///
/// This is useful for accessing the worker from inside `par_iter` where the initial `&Worker`
/// reference from `install` belongs to a different thread.
pub fn with_worker<R>(f: impl FnOnce(&Worker) -> R) -> R {
WORKER.with_borrow(|worker| f(worker))
}
}
/// Per-thread state container for a [`WorkerPool`].
///
/// Holds a type-erased `Box<dyn Any>` that can be initialized and accessed with concrete types
/// via [`init`](Self::init) and [`get`](Self::get).
#[derive(Debug, Default)]
pub struct Worker {
state: Option<Box<dyn Any>>,
}
impl Worker {
/// Creates a new empty `Worker`.
const fn new() -> Self {
Self { state: None }
}
/// Initializes the worker state.
///
/// If state of type `T` already exists, passes `Some(&mut T)` to the closure so resources
/// can be reused. On first init, passes `None`.
pub fn init<T: 'static>(&mut self, f: impl FnOnce(Option<&mut T>) -> T) {
let existing =
self.state.take().and_then(|mut b| b.downcast_mut::<T>().is_some().then_some(b));
let new_state = match existing {
Some(mut boxed) => {
let r = boxed.downcast_mut::<T>().expect("type checked above");
*r = f(Some(r));
boxed
}
None => Box::new(f(None)),
};
self.state = Some(new_state);
}
/// Returns a reference to the state, downcasted to `T`.
///
/// # Panics
///
/// Panics if the worker has not been initialized or if the type does not match.
pub fn get<T: 'static>(&self) -> &T {
self.state
.as_ref()
.expect("worker not initialized")
.downcast_ref::<T>()
.expect("worker state type mismatch")
}
/// Clears the worker state, dropping the contained value.
pub fn clear(&mut self) {
self.state = None;
}
}
#[cfg(test)]
mod tests {
use super::*;
@@ -172,4 +339,89 @@ mod tests {
let res = res.await;
assert!(res.is_err());
}
#[test]
fn worker_pool_init_and_access() {
let pool = WorkerPool::new(2).unwrap();
pool.broadcast(2, |worker| {
worker.init::<Vec<u8>>(|_| vec![1, 2, 3]);
});
let sum: u8 = pool.install(|worker| {
let v = worker.get::<Vec<u8>>();
v.iter().sum()
});
assert_eq!(sum, 6);
pool.clear();
}
#[test]
fn worker_pool_reinit_reuses_resources() {
let pool = WorkerPool::new(1).unwrap();
pool.broadcast(1, |worker| {
worker.init::<Vec<u8>>(|existing| {
assert!(existing.is_none());
vec![1, 2, 3]
});
});
pool.broadcast(1, |worker| {
worker.init::<Vec<u8>>(|existing| {
let v = existing.expect("should have existing state");
assert_eq!(v, &mut vec![1, 2, 3]);
v.push(4);
std::mem::take(v)
});
});
let len = pool.install(|worker| worker.get::<Vec<u8>>().len());
assert_eq!(len, 4);
pool.clear();
}
#[test]
fn worker_pool_clear_and_reinit() {
let pool = WorkerPool::new(1).unwrap();
pool.broadcast(1, |worker| {
worker.init::<u64>(|_| 42);
});
let val = pool.install(|worker| *worker.get::<u64>());
assert_eq!(val, 42);
pool.clear();
pool.broadcast(1, |worker| {
worker.init::<String>(|_| "hello".to_string());
});
let val = pool.install(|worker| worker.get::<String>().clone());
assert_eq!(val, "hello");
pool.clear();
}
#[test]
fn worker_pool_par_iter_with_worker() {
use rayon::prelude::*;
let pool = WorkerPool::new(2).unwrap();
pool.broadcast(2, |worker| {
worker.init::<u64>(|_| 10);
});
let results: Vec<u64> = pool.install(|_| {
(0u64..4)
.into_par_iter()
.map(|i| WorkerPool::with_worker(|w| i + *w.get::<u64>()))
.collect()
});
assert_eq!(results, vec![10, 11, 12, 13]);
pool.clear();
}
}

View File

@@ -7,7 +7,7 @@
//! - [`BlockingTaskGuard`] for rate-limiting expensive operations (with `rayon` feature)
#[cfg(feature = "rayon")]
use crate::pool::{BlockingTaskGuard, BlockingTaskPool};
use crate::pool::{BlockingTaskGuard, BlockingTaskPool, WorkerPool};
use crate::{
metrics::{IncCounterOnDrop, TaskExecutorMetrics},
shutdown::{GracefulShutdown, GracefulShutdownGuard, Shutdown},
@@ -263,13 +263,13 @@ struct RuntimeInner {
blocking_guard: BlockingTaskGuard,
/// Proof storage worker pool (trie storage proof computation).
#[cfg(feature = "rayon")]
proof_storage_worker_pool: rayon::ThreadPool,
proof_storage_worker_pool: WorkerPool,
/// Proof account worker pool (trie account proof computation).
#[cfg(feature = "rayon")]
proof_account_worker_pool: rayon::ThreadPool,
proof_account_worker_pool: WorkerPool,
/// Prewarming pool (execution prewarming workers).
#[cfg(feature = "rayon")]
prewarming_pool: rayon::ThreadPool,
prewarming_pool: WorkerPool,
/// Handle to the spawned [`TaskManager`] background task.
/// The task monitors critical tasks for panics and fires the shutdown signal.
/// Can be taken via [`Runtime::take_task_manager_handle`] to poll for panic errors.
@@ -349,19 +349,19 @@ impl Runtime {
/// Get the proof storage worker pool.
#[cfg(feature = "rayon")]
pub fn proof_storage_worker_pool(&self) -> &rayon::ThreadPool {
pub fn proof_storage_worker_pool(&self) -> &WorkerPool {
&self.0.proof_storage_worker_pool
}
/// Get the proof account worker pool.
#[cfg(feature = "rayon")]
pub fn proof_account_worker_pool(&self) -> &rayon::ThreadPool {
pub fn proof_account_worker_pool(&self) -> &WorkerPool {
&self.0.proof_account_worker_pool
}
/// Get the prewarming pool.
#[cfg(feature = "rayon")]
pub fn prewarming_pool(&self) -> &rayon::ThreadPool {
pub fn prewarming_pool(&self) -> &WorkerPool {
&self.0.prewarming_pool
}
}
@@ -809,23 +809,26 @@ impl RuntimeBuilder {
let proof_storage_worker_threads =
config.rayon.proof_storage_worker_threads.unwrap_or(default_threads);
let proof_storage_worker_pool = rayon::ThreadPoolBuilder::new()
.num_threads(proof_storage_worker_threads)
.thread_name(|i| format!("proof-strg-{i:02}"))
.build()?;
let proof_storage_worker_pool = WorkerPool::from_builder(
rayon::ThreadPoolBuilder::new()
.num_threads(proof_storage_worker_threads)
.thread_name(|i| format!("proof-strg-{i:02}")),
)?;
let proof_account_worker_threads =
config.rayon.proof_account_worker_threads.unwrap_or(default_threads);
let proof_account_worker_pool = rayon::ThreadPoolBuilder::new()
.num_threads(proof_account_worker_threads)
.thread_name(|i| format!("proof-acct-{i:02}"))
.build()?;
let proof_account_worker_pool = WorkerPool::from_builder(
rayon::ThreadPoolBuilder::new()
.num_threads(proof_account_worker_threads)
.thread_name(|i| format!("proof-acct-{i:02}")),
)?;
let prewarming_threads = config.rayon.prewarming_threads.unwrap_or(default_threads);
let prewarming_pool = rayon::ThreadPoolBuilder::new()
.num_threads(prewarming_threads)
.thread_name(|i| format!("prewarm-{i:02}"))
.build()?;
let prewarming_pool = WorkerPool::from_builder(
rayon::ThreadPoolBuilder::new()
.num_threads(prewarming_threads)
.thread_name(|i| format!("prewarm-{i:02}")),
)?;
debug!(
default_threads,

View File

@@ -155,6 +155,7 @@ impl ProofWorkerHandle {
Factory: DatabaseProviderROFactory<Provider: TrieCursorFactory + HashedCursorFactory>
+ Clone
+ Send
+ Sync
+ 'static,
{
let (storage_work_tx, storage_work_rx) = unbounded::<StorageWorkerJob>();
@@ -180,30 +181,30 @@ impl ProofWorkerHandle {
"Spawning proof worker pools"
);
let storage_pool = runtime.proof_storage_worker_pool();
let task_ctx_for_storage = task_ctx.clone();
let cached_storage_roots_for_storage = cached_storage_roots.clone();
// broadcast blocks until all workers exit (channel close), so run on
// tokio's blocking pool.
let storage_rt = runtime.clone();
let storage_task_ctx = task_ctx.clone();
let storage_avail = storage_available_workers.clone();
let storage_roots = cached_storage_roots.clone();
runtime.spawn_blocking(move || {
let worker_id = AtomicUsize::new(0);
storage_rt.proof_storage_worker_pool().broadcast(storage_worker_count, |_| {
let worker_id = worker_id.fetch_add(1, Ordering::Relaxed);
let span = debug_span!(target: "trie::proof_task", "storage worker", ?worker_id);
let _guard = span.enter();
for worker_id in 0..storage_worker_count {
let span = debug_span!(target: "trie::proof_task", "storage worker", ?worker_id);
let task_ctx_clone = task_ctx_for_storage.clone();
let work_rx_clone = storage_work_rx.clone();
let storage_available_workers_clone = storage_available_workers.clone();
let cached_storage_roots = cached_storage_roots_for_storage.clone();
storage_pool.spawn(move || {
#[cfg(feature = "metrics")]
let metrics = ProofTaskTrieMetrics::default();
#[cfg(feature = "metrics")]
let cursor_metrics = ProofTaskCursorMetrics::new();
let _guard = span.enter();
let worker = StorageProofWorker::new(
task_ctx_clone,
work_rx_clone,
storage_task_ctx.clone(),
storage_work_rx.clone(),
worker_id,
storage_available_workers_clone,
cached_storage_roots,
storage_avail.clone(),
storage_roots.clone(),
#[cfg(feature = "metrics")]
metrics,
#[cfg(feature = "metrics")]
@@ -219,32 +220,30 @@ impl ProofWorkerHandle {
);
}
});
}
});
let account_pool = runtime.proof_account_worker_pool();
let account_rt = runtime.clone();
let account_tx = storage_work_tx.clone();
let account_avail = account_available_workers.clone();
runtime.spawn_blocking(move || {
let worker_id = AtomicUsize::new(0);
account_rt.proof_account_worker_pool().broadcast(account_worker_count, |_| {
let worker_id = worker_id.fetch_add(1, Ordering::Relaxed);
let span = debug_span!(target: "trie::proof_task", "account worker", ?worker_id);
let _guard = span.enter();
for worker_id in 0..account_worker_count {
let span = debug_span!(target: "trie::proof_task", "account worker", ?worker_id);
let task_ctx_clone = task_ctx.clone();
let work_rx_clone = account_work_rx.clone();
let storage_work_tx_clone = storage_work_tx.clone();
let account_available_workers_clone = account_available_workers.clone();
let cached_storage_roots = cached_storage_roots.clone();
account_pool.spawn(move || {
#[cfg(feature = "metrics")]
let metrics = ProofTaskTrieMetrics::default();
#[cfg(feature = "metrics")]
let cursor_metrics = ProofTaskCursorMetrics::new();
let _guard = span.enter();
let worker = AccountProofWorker::new(
task_ctx_clone,
work_rx_clone,
task_ctx.clone(),
account_work_rx.clone(),
worker_id,
storage_work_tx_clone,
account_available_workers_clone,
cached_storage_roots,
account_tx.clone(),
account_avail.clone(),
cached_storage_roots.clone(),
#[cfg(feature = "metrics")]
metrics,
#[cfg(feature = "metrics")]
@@ -260,7 +259,7 @@ impl ProofWorkerHandle {
);
}
});
}
});
Self {
storage_work_tx,