From 2eec519bf984cffb2bf4bd76012f4b3b5fe3e5e6 Mon Sep 17 00:00:00 2001 From: Georgios Konstantopoulos Date: Mon, 16 Feb 2026 15:46:51 -0800 Subject: [PATCH] feat(tasks): add WorkerPool with per-thread Worker state (#22154) Co-authored-by: Amp --- .../tree/src/tree/payload_processor/mod.rs | 1 + .../src/tree/payload_processor/prewarm.rs | 2 +- crates/tasks/src/lib.rs | 2 + crates/tasks/src/pool.rs | 254 +++++++++++++++++- crates/tasks/src/runtime.rs | 41 +-- crates/trie/parallel/src/proof_task.rs | 67 +++-- 6 files changed, 312 insertions(+), 55 deletions(-) diff --git a/crates/engine/tree/src/tree/payload_processor/mod.rs b/crates/engine/tree/src/tree/payload_processor/mod.rs index d012aa9e7a..ce32f90f65 100644 --- a/crates/engine/tree/src/tree/payload_processor/mod.rs +++ b/crates/engine/tree/src/tree/payload_processor/mod.rs @@ -276,6 +276,7 @@ where F: DatabaseProviderROFactory + Clone + Send + + Sync + 'static, { // start preparing transactions immediately diff --git a/crates/engine/tree/src/tree/payload_processor/prewarm.rs b/crates/engine/tree/src/tree/payload_processor/prewarm.rs index a432c1b3e8..b27a10defc 100644 --- a/crates/engine/tree/src/tree/payload_processor/prewarm.rs +++ b/crates/engine/tree/src/tree/payload_processor/prewarm.rs @@ -299,7 +299,7 @@ where ); let ctx = self.ctx.clone(); - self.executor.prewarming_pool().install(|| { + self.executor.prewarming_pool().install_fn(|| { bal.par_iter().for_each_init( || (ctx.clone(), None::>), |(ctx, provider), account| { diff --git a/crates/tasks/src/lib.rs b/crates/tasks/src/lib.rs index cd2f4c5525..29859611a0 100644 --- a/crates/tasks/src/lib.rs +++ b/crates/tasks/src/lib.rs @@ -36,6 +36,8 @@ pub mod shutdown; #[cfg(feature = "rayon")] pub mod pool; +#[cfg(feature = "rayon")] +pub use pool::{Worker, WorkerPool}; /// Lock-free ordered parallel iterator extension trait. #[cfg(feature = "rayon")] diff --git a/crates/tasks/src/pool.rs b/crates/tasks/src/pool.rs index 76087b71ef..f3cca7cf9a 100644 --- a/crates/tasks/src/pool.rs +++ b/crates/tasks/src/pool.rs @@ -1,10 +1,15 @@ //! Additional helpers for executing tracing calls use std::{ + any::Any, + cell::RefCell, future::Future, panic::{catch_unwind, AssertUnwindSafe}, pin::Pin, - sync::Arc, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, + }, task::{ready, Context, Poll}, thread, }; @@ -151,6 +156,168 @@ impl Future for BlockingTaskHandle { #[non_exhaustive] pub struct TokioBlockingTaskError; +thread_local! { + static WORKER: RefCell = const { RefCell::new(Worker::new()) }; +} + +/// A rayon thread pool with per-thread [`Worker`] state. +/// +/// Each thread in the pool has its own [`Worker`] that can hold arbitrary state via +/// [`Worker::init`]. The state is thread-local and accessible during [`install`](Self::install) +/// calls. +/// +/// The pool supports multiple init/clear cycles, allowing reuse of the same threads with +/// different state configurations. +#[derive(Debug)] +pub struct WorkerPool { + pool: rayon::ThreadPool, +} + +impl WorkerPool { + /// Creates a new `WorkerPool` with the given number of threads. + pub fn new(num_threads: usize) -> Result { + Self::from_builder(rayon::ThreadPoolBuilder::new().num_threads(num_threads)) + } + + /// Creates a new `WorkerPool` from a [`rayon::ThreadPoolBuilder`]. + pub fn from_builder( + builder: rayon::ThreadPoolBuilder, + ) -> Result { + Ok(Self { pool: builder.build()? }) + } + + /// Returns the total number of threads in the underlying rayon pool. + pub fn current_num_threads(&self) -> usize { + self.pool.current_num_threads() + } + + /// Runs a closure on `num_threads` threads in the pool, giving mutable access to each + /// thread's [`Worker`]. + /// + /// Use this to initialize or re-initialize per-thread state via [`Worker::init`]. + /// Only `num_threads` threads execute the closure; the rest skip it. + pub fn broadcast(&self, num_threads: usize, f: impl Fn(&mut Worker) + Sync) { + if num_threads >= self.pool.current_num_threads() { + // Fast path: run on every thread, no atomic coordination needed. + self.pool.broadcast(|_| { + WORKER.with_borrow_mut(|worker| f(worker)); + }); + } else { + let remaining = AtomicUsize::new(num_threads); + self.pool.broadcast(|_| { + // Atomically claim a slot; threads that can't decrement skip the closure. + let mut current = remaining.load(Ordering::Relaxed); + loop { + if current == 0 { + return; + } + match remaining.compare_exchange_weak( + current, + current - 1, + Ordering::Relaxed, + Ordering::Relaxed, + ) { + Ok(_) => break, + Err(actual) => current = actual, + } + } + WORKER.with_borrow_mut(|worker| f(worker)); + }); + } + } + + /// Clears the state on every thread in the pool. + pub fn clear(&self) { + self.pool.broadcast(|_| { + WORKER.with_borrow_mut(Worker::clear); + }); + } + + /// Runs a closure on the pool with access to the calling thread's [`Worker`]. + /// + /// All rayon parallelism (e.g. `par_iter`) spawned inside the closure executes on this pool. + /// Each thread can access its own [`Worker`] via the provided reference or through additional + /// [`WorkerPool::with_worker`] calls. + pub fn install(&self, f: impl FnOnce(&Worker) -> R + Send) -> R { + self.pool.install(|| WORKER.with_borrow(|worker| f(worker))) + } + + /// Runs a closure on the pool without worker state access. + /// + /// Like [`install`](Self::install) but for closures that don't need per-thread [`Worker`] + /// state. + pub fn install_fn(&self, f: impl FnOnce() -> R + Send) -> R { + self.pool.install(f) + } + + /// Spawns a closure on the pool. + pub fn spawn(&self, f: impl FnOnce() + Send + 'static) { + self.pool.spawn(f); + } + + /// Access the current thread's [`Worker`] from within an [`install`](Self::install) closure. + /// + /// This is useful for accessing the worker from inside `par_iter` where the initial `&Worker` + /// reference from `install` belongs to a different thread. + pub fn with_worker(f: impl FnOnce(&Worker) -> R) -> R { + WORKER.with_borrow(|worker| f(worker)) + } +} + +/// Per-thread state container for a [`WorkerPool`]. +/// +/// Holds a type-erased `Box` that can be initialized and accessed with concrete types +/// via [`init`](Self::init) and [`get`](Self::get). +#[derive(Debug, Default)] +pub struct Worker { + state: Option>, +} + +impl Worker { + /// Creates a new empty `Worker`. + const fn new() -> Self { + Self { state: None } + } + + /// Initializes the worker state. + /// + /// If state of type `T` already exists, passes `Some(&mut T)` to the closure so resources + /// can be reused. On first init, passes `None`. + pub fn init(&mut self, f: impl FnOnce(Option<&mut T>) -> T) { + let existing = + self.state.take().and_then(|mut b| b.downcast_mut::().is_some().then_some(b)); + + let new_state = match existing { + Some(mut boxed) => { + let r = boxed.downcast_mut::().expect("type checked above"); + *r = f(Some(r)); + boxed + } + None => Box::new(f(None)), + }; + + self.state = Some(new_state); + } + + /// Returns a reference to the state, downcasted to `T`. + /// + /// # Panics + /// + /// Panics if the worker has not been initialized or if the type does not match. + pub fn get(&self) -> &T { + self.state + .as_ref() + .expect("worker not initialized") + .downcast_ref::() + .expect("worker state type mismatch") + } + + /// Clears the worker state, dropping the contained value. + pub fn clear(&mut self) { + self.state = None; + } +} + #[cfg(test)] mod tests { use super::*; @@ -172,4 +339,89 @@ mod tests { let res = res.await; assert!(res.is_err()); } + + #[test] + fn worker_pool_init_and_access() { + let pool = WorkerPool::new(2).unwrap(); + + pool.broadcast(2, |worker| { + worker.init::>(|_| vec![1, 2, 3]); + }); + + let sum: u8 = pool.install(|worker| { + let v = worker.get::>(); + v.iter().sum() + }); + assert_eq!(sum, 6); + + pool.clear(); + } + + #[test] + fn worker_pool_reinit_reuses_resources() { + let pool = WorkerPool::new(1).unwrap(); + + pool.broadcast(1, |worker| { + worker.init::>(|existing| { + assert!(existing.is_none()); + vec![1, 2, 3] + }); + }); + + pool.broadcast(1, |worker| { + worker.init::>(|existing| { + let v = existing.expect("should have existing state"); + assert_eq!(v, &mut vec![1, 2, 3]); + v.push(4); + std::mem::take(v) + }); + }); + + let len = pool.install(|worker| worker.get::>().len()); + assert_eq!(len, 4); + + pool.clear(); + } + + #[test] + fn worker_pool_clear_and_reinit() { + let pool = WorkerPool::new(1).unwrap(); + + pool.broadcast(1, |worker| { + worker.init::(|_| 42); + }); + let val = pool.install(|worker| *worker.get::()); + assert_eq!(val, 42); + + pool.clear(); + + pool.broadcast(1, |worker| { + worker.init::(|_| "hello".to_string()); + }); + let val = pool.install(|worker| worker.get::().clone()); + assert_eq!(val, "hello"); + + pool.clear(); + } + + #[test] + fn worker_pool_par_iter_with_worker() { + use rayon::prelude::*; + + let pool = WorkerPool::new(2).unwrap(); + + pool.broadcast(2, |worker| { + worker.init::(|_| 10); + }); + + let results: Vec = pool.install(|_| { + (0u64..4) + .into_par_iter() + .map(|i| WorkerPool::with_worker(|w| i + *w.get::())) + .collect() + }); + assert_eq!(results, vec![10, 11, 12, 13]); + + pool.clear(); + } } diff --git a/crates/tasks/src/runtime.rs b/crates/tasks/src/runtime.rs index e0887a5b68..a644e43b83 100644 --- a/crates/tasks/src/runtime.rs +++ b/crates/tasks/src/runtime.rs @@ -7,7 +7,7 @@ //! - [`BlockingTaskGuard`] for rate-limiting expensive operations (with `rayon` feature) #[cfg(feature = "rayon")] -use crate::pool::{BlockingTaskGuard, BlockingTaskPool}; +use crate::pool::{BlockingTaskGuard, BlockingTaskPool, WorkerPool}; use crate::{ metrics::{IncCounterOnDrop, TaskExecutorMetrics}, shutdown::{GracefulShutdown, GracefulShutdownGuard, Shutdown}, @@ -263,13 +263,13 @@ struct RuntimeInner { blocking_guard: BlockingTaskGuard, /// Proof storage worker pool (trie storage proof computation). #[cfg(feature = "rayon")] - proof_storage_worker_pool: rayon::ThreadPool, + proof_storage_worker_pool: WorkerPool, /// Proof account worker pool (trie account proof computation). #[cfg(feature = "rayon")] - proof_account_worker_pool: rayon::ThreadPool, + proof_account_worker_pool: WorkerPool, /// Prewarming pool (execution prewarming workers). #[cfg(feature = "rayon")] - prewarming_pool: rayon::ThreadPool, + prewarming_pool: WorkerPool, /// Handle to the spawned [`TaskManager`] background task. /// The task monitors critical tasks for panics and fires the shutdown signal. /// Can be taken via [`Runtime::take_task_manager_handle`] to poll for panic errors. @@ -349,19 +349,19 @@ impl Runtime { /// Get the proof storage worker pool. #[cfg(feature = "rayon")] - pub fn proof_storage_worker_pool(&self) -> &rayon::ThreadPool { + pub fn proof_storage_worker_pool(&self) -> &WorkerPool { &self.0.proof_storage_worker_pool } /// Get the proof account worker pool. #[cfg(feature = "rayon")] - pub fn proof_account_worker_pool(&self) -> &rayon::ThreadPool { + pub fn proof_account_worker_pool(&self) -> &WorkerPool { &self.0.proof_account_worker_pool } /// Get the prewarming pool. #[cfg(feature = "rayon")] - pub fn prewarming_pool(&self) -> &rayon::ThreadPool { + pub fn prewarming_pool(&self) -> &WorkerPool { &self.0.prewarming_pool } } @@ -809,23 +809,26 @@ impl RuntimeBuilder { let proof_storage_worker_threads = config.rayon.proof_storage_worker_threads.unwrap_or(default_threads); - let proof_storage_worker_pool = rayon::ThreadPoolBuilder::new() - .num_threads(proof_storage_worker_threads) - .thread_name(|i| format!("proof-strg-{i:02}")) - .build()?; + let proof_storage_worker_pool = WorkerPool::from_builder( + rayon::ThreadPoolBuilder::new() + .num_threads(proof_storage_worker_threads) + .thread_name(|i| format!("proof-strg-{i:02}")), + )?; let proof_account_worker_threads = config.rayon.proof_account_worker_threads.unwrap_or(default_threads); - let proof_account_worker_pool = rayon::ThreadPoolBuilder::new() - .num_threads(proof_account_worker_threads) - .thread_name(|i| format!("proof-acct-{i:02}")) - .build()?; + let proof_account_worker_pool = WorkerPool::from_builder( + rayon::ThreadPoolBuilder::new() + .num_threads(proof_account_worker_threads) + .thread_name(|i| format!("proof-acct-{i:02}")), + )?; let prewarming_threads = config.rayon.prewarming_threads.unwrap_or(default_threads); - let prewarming_pool = rayon::ThreadPoolBuilder::new() - .num_threads(prewarming_threads) - .thread_name(|i| format!("prewarm-{i:02}")) - .build()?; + let prewarming_pool = WorkerPool::from_builder( + rayon::ThreadPoolBuilder::new() + .num_threads(prewarming_threads) + .thread_name(|i| format!("prewarm-{i:02}")), + )?; debug!( default_threads, diff --git a/crates/trie/parallel/src/proof_task.rs b/crates/trie/parallel/src/proof_task.rs index 879297c918..c876f21833 100644 --- a/crates/trie/parallel/src/proof_task.rs +++ b/crates/trie/parallel/src/proof_task.rs @@ -155,6 +155,7 @@ impl ProofWorkerHandle { Factory: DatabaseProviderROFactory + Clone + Send + + Sync + 'static, { let (storage_work_tx, storage_work_rx) = unbounded::(); @@ -180,30 +181,30 @@ impl ProofWorkerHandle { "Spawning proof worker pools" ); - let storage_pool = runtime.proof_storage_worker_pool(); - let task_ctx_for_storage = task_ctx.clone(); - let cached_storage_roots_for_storage = cached_storage_roots.clone(); + // broadcast blocks until all workers exit (channel close), so run on + // tokio's blocking pool. + let storage_rt = runtime.clone(); + let storage_task_ctx = task_ctx.clone(); + let storage_avail = storage_available_workers.clone(); + let storage_roots = cached_storage_roots.clone(); + runtime.spawn_blocking(move || { + let worker_id = AtomicUsize::new(0); + storage_rt.proof_storage_worker_pool().broadcast(storage_worker_count, |_| { + let worker_id = worker_id.fetch_add(1, Ordering::Relaxed); + let span = debug_span!(target: "trie::proof_task", "storage worker", ?worker_id); + let _guard = span.enter(); - for worker_id in 0..storage_worker_count { - let span = debug_span!(target: "trie::proof_task", "storage worker", ?worker_id); - let task_ctx_clone = task_ctx_for_storage.clone(); - let work_rx_clone = storage_work_rx.clone(); - let storage_available_workers_clone = storage_available_workers.clone(); - let cached_storage_roots = cached_storage_roots_for_storage.clone(); - - storage_pool.spawn(move || { #[cfg(feature = "metrics")] let metrics = ProofTaskTrieMetrics::default(); #[cfg(feature = "metrics")] let cursor_metrics = ProofTaskCursorMetrics::new(); - let _guard = span.enter(); let worker = StorageProofWorker::new( - task_ctx_clone, - work_rx_clone, + storage_task_ctx.clone(), + storage_work_rx.clone(), worker_id, - storage_available_workers_clone, - cached_storage_roots, + storage_avail.clone(), + storage_roots.clone(), #[cfg(feature = "metrics")] metrics, #[cfg(feature = "metrics")] @@ -219,32 +220,30 @@ impl ProofWorkerHandle { ); } }); - } + }); - let account_pool = runtime.proof_account_worker_pool(); + let account_rt = runtime.clone(); + let account_tx = storage_work_tx.clone(); + let account_avail = account_available_workers.clone(); + runtime.spawn_blocking(move || { + let worker_id = AtomicUsize::new(0); + account_rt.proof_account_worker_pool().broadcast(account_worker_count, |_| { + let worker_id = worker_id.fetch_add(1, Ordering::Relaxed); + let span = debug_span!(target: "trie::proof_task", "account worker", ?worker_id); + let _guard = span.enter(); - for worker_id in 0..account_worker_count { - let span = debug_span!(target: "trie::proof_task", "account worker", ?worker_id); - let task_ctx_clone = task_ctx.clone(); - let work_rx_clone = account_work_rx.clone(); - let storage_work_tx_clone = storage_work_tx.clone(); - let account_available_workers_clone = account_available_workers.clone(); - let cached_storage_roots = cached_storage_roots.clone(); - - account_pool.spawn(move || { #[cfg(feature = "metrics")] let metrics = ProofTaskTrieMetrics::default(); #[cfg(feature = "metrics")] let cursor_metrics = ProofTaskCursorMetrics::new(); - let _guard = span.enter(); let worker = AccountProofWorker::new( - task_ctx_clone, - work_rx_clone, + task_ctx.clone(), + account_work_rx.clone(), worker_id, - storage_work_tx_clone, - account_available_workers_clone, - cached_storage_roots, + account_tx.clone(), + account_avail.clone(), + cached_storage_roots.clone(), #[cfg(feature = "metrics")] metrics, #[cfg(feature = "metrics")] @@ -260,7 +259,7 @@ impl ProofWorkerHandle { ); } }); - } + }); Self { storage_work_tx,