feat(tasks): add WorkerPool with per-thread Worker state (#22154)
Co-authored-by: Amp <amp@ampcode.com>
commit 2eec519bf9 · parent 02513ecf3b · committed via GitHub
@@ -276,6 +276,7 @@ where
        F: DatabaseProviderROFactory<Provider: TrieCursorFactory + HashedCursorFactory>
            + Clone
            + Send
            + Sync
            + 'static,
    {
        // start preparing transactions immediately

@@ -299,7 +299,7 @@ where
        );

        let ctx = self.ctx.clone();
-       self.executor.prewarming_pool().install(|| {
+       self.executor.prewarming_pool().install_fn(|| {
            bal.par_iter().for_each_init(
                || (ctx.clone(), None::<CachedStateProvider<reth_provider::StateProviderBox>>),
                |(ctx, provider), account| {

@@ -36,6 +36,8 @@ pub mod shutdown;

#[cfg(feature = "rayon")]
pub mod pool;
#[cfg(feature = "rayon")]
pub use pool::{Worker, WorkerPool};

/// Lock-free ordered parallel iterator extension trait.
#[cfg(feature = "rayon")]

@@ -1,10 +1,15 @@
//! Additional helpers for executing tracing calls

use std::{
    any::Any,
    cell::RefCell,
    future::Future,
    panic::{catch_unwind, AssertUnwindSafe},
    pin::Pin,
-   sync::Arc,
+   sync::{
+       atomic::{AtomicUsize, Ordering},
+       Arc,
+   },
    task::{ready, Context, Poll},
    thread,
};

@@ -151,6 +156,168 @@ impl<T> Future for BlockingTaskHandle<T> {
#[non_exhaustive]
pub struct TokioBlockingTaskError;

thread_local! {
    static WORKER: RefCell<Worker> = const { RefCell::new(Worker::new()) };
}

/// A rayon thread pool with per-thread [`Worker`] state.
///
/// Each thread in the pool has its own [`Worker`] that can hold arbitrary state via
/// [`Worker::init`]. The state is thread-local and accessible during [`install`](Self::install)
/// calls.
///
/// The pool supports multiple init/clear cycles, allowing reuse of the same threads with
/// different state configurations.
#[derive(Debug)]
pub struct WorkerPool {
    pool: rayon::ThreadPool,
}

impl WorkerPool {
    /// Creates a new `WorkerPool` with the given number of threads.
    pub fn new(num_threads: usize) -> Result<Self, rayon::ThreadPoolBuildError> {
        Self::from_builder(rayon::ThreadPoolBuilder::new().num_threads(num_threads))
    }

    /// Creates a new `WorkerPool` from a [`rayon::ThreadPoolBuilder`].
    pub fn from_builder(
        builder: rayon::ThreadPoolBuilder,
    ) -> Result<Self, rayon::ThreadPoolBuildError> {
        Ok(Self { pool: builder.build()? })
    }

    /// Returns the total number of threads in the underlying rayon pool.
    pub fn current_num_threads(&self) -> usize {
        self.pool.current_num_threads()
    }

    /// Runs a closure on `num_threads` threads in the pool, giving mutable access to each
    /// thread's [`Worker`].
    ///
    /// Use this to initialize or re-initialize per-thread state via [`Worker::init`].
    /// Only `num_threads` threads execute the closure; the rest skip it.
    pub fn broadcast(&self, num_threads: usize, f: impl Fn(&mut Worker) + Sync) {
        if num_threads >= self.pool.current_num_threads() {
            // Fast path: run on every thread, no atomic coordination needed.
            self.pool.broadcast(|_| {
                WORKER.with_borrow_mut(|worker| f(worker));
            });
        } else {
            let remaining = AtomicUsize::new(num_threads);
            self.pool.broadcast(|_| {
                // Atomically claim a slot; threads that can't decrement skip the closure.
                let mut current = remaining.load(Ordering::Relaxed);
                loop {
                    if current == 0 {
                        return;
                    }
                    match remaining.compare_exchange_weak(
                        current,
                        current - 1,
                        Ordering::Relaxed,
                        Ordering::Relaxed,
                    ) {
                        Ok(_) => break,
                        Err(actual) => current = actual,
                    }
                }
                WORKER.with_borrow_mut(|worker| f(worker));
            });
        }
    }

    /// Clears the state on every thread in the pool.
    pub fn clear(&self) {
        self.pool.broadcast(|_| {
            WORKER.with_borrow_mut(Worker::clear);
        });
    }

    /// Runs a closure on the pool with access to the calling thread's [`Worker`].
    ///
    /// All rayon parallelism (e.g. `par_iter`) spawned inside the closure executes on this pool.
    /// Each thread can access its own [`Worker`] via the provided reference or through additional
    /// [`WorkerPool::with_worker`] calls.
    pub fn install<R: Send>(&self, f: impl FnOnce(&Worker) -> R + Send) -> R {
        self.pool.install(|| WORKER.with_borrow(|worker| f(worker)))
    }

    /// Runs a closure on the pool without worker state access.
    ///
    /// Like [`install`](Self::install) but for closures that don't need per-thread [`Worker`]
    /// state.
    pub fn install_fn<R: Send>(&self, f: impl FnOnce() -> R + Send) -> R {
        self.pool.install(f)
    }

    /// Spawns a closure on the pool.
    pub fn spawn(&self, f: impl FnOnce() + Send + 'static) {
        self.pool.spawn(f);
    }

    /// Access the current thread's [`Worker`] from within an [`install`](Self::install) closure.
    ///
    /// This is useful for accessing the worker from inside `par_iter` where the initial `&Worker`
    /// reference from `install` belongs to a different thread.
    pub fn with_worker<R>(f: impl FnOnce(&Worker) -> R) -> R {
        WORKER.with_borrow(|worker| f(worker))
    }
}
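
// Editor's illustration (not part of this diff): a minimal sketch of the intended flow
// using only the APIs above — `broadcast` to initialize per-thread state, `install` to
// run parallel work on the pool, `with_worker` to reach the executing thread's state
// from inside `par_iter`, and `clear` to drop the state afterwards. The function name
// is hypothetical.
#[allow(dead_code)]
fn worker_pool_usage_sketch() -> Result<(), rayon::ThreadPoolBuildError> {
    use rayon::prelude::*;

    let pool = WorkerPool::new(4)?;
    // Give every thread a per-thread multiplier; initialize all threads, since any of
    // them may pick up `par_iter` items below.
    pool.broadcast(4, |worker| worker.init::<u64>(|_| 100));

    let results: Vec<u64> = pool.install(|_| {
        (0u64..8)
            .into_par_iter()
            .map(|i| WorkerPool::with_worker(|w| i * w.get::<u64>()))
            .collect()
    });
    assert_eq!(results, (0..8).map(|i| i * 100).collect::<Vec<u64>>());

    // Drop the per-thread state once the batch is done.
    pool.clear();
    Ok(())
}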

/// Per-thread state container for a [`WorkerPool`].
///
/// Holds a type-erased `Box<dyn Any>` that can be initialized and accessed with concrete types
/// via [`init`](Self::init) and [`get`](Self::get).
#[derive(Debug, Default)]
pub struct Worker {
    state: Option<Box<dyn Any>>,
}

impl Worker {
    /// Creates a new empty `Worker`.
    const fn new() -> Self {
        Self { state: None }
    }

    /// Initializes the worker state.
    ///
    /// If state of type `T` already exists, passes `Some(&mut T)` to the closure so resources
    /// can be reused. On first init, passes `None`.
    pub fn init<T: 'static>(&mut self, f: impl FnOnce(Option<&mut T>) -> T) {
        let existing =
            self.state.take().and_then(|mut b| b.downcast_mut::<T>().is_some().then_some(b));

        let new_state = match existing {
            Some(mut boxed) => {
                let r = boxed.downcast_mut::<T>().expect("type checked above");
                *r = f(Some(r));
                boxed
            }
            None => Box::new(f(None)),
        };

        self.state = Some(new_state);
    }

    /// Returns a reference to the state, downcasted to `T`.
    ///
    /// # Panics
    ///
    /// Panics if the worker has not been initialized or if the type does not match.
    pub fn get<T: 'static>(&self) -> &T {
        self.state
            .as_ref()
            .expect("worker not initialized")
            .downcast_ref::<T>()
            .expect("worker state type mismatch")
    }

    /// Clears the worker state, dropping the contained value.
    pub fn clear(&mut self) {
        self.state = None;
    }
}
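
// Editor's illustration (not part of this diff): how `Worker::init` supports reuse.
// On the first `broadcast` the closure sees `None` and builds the state; on a later
// `broadcast` with the same type it sees `Some(&mut T)` and can recycle what is
// already there (here: keep the allocation, just clear the contents). The function
// name is hypothetical.
#[allow(dead_code)]
fn worker_reinit_sketch(pool: &WorkerPool, threads: usize) {
    pool.broadcast(threads, |worker| {
        worker.init::<Vec<u8>>(|existing| match existing {
            // Reuse the previous buffer's allocation for the next run.
            Some(buf) => {
                buf.clear();
                std::mem::take(buf)
            }
            // First initialization: allocate fresh.
            None => Vec::with_capacity(4096),
        });
    });
}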

#[cfg(test)]
mod tests {
    use super::*;

@@ -172,4 +339,89 @@ mod tests {
        let res = res.await;
        assert!(res.is_err());
    }

    #[test]
    fn worker_pool_init_and_access() {
        let pool = WorkerPool::new(2).unwrap();

        pool.broadcast(2, |worker| {
            worker.init::<Vec<u8>>(|_| vec![1, 2, 3]);
        });

        let sum: u8 = pool.install(|worker| {
            let v = worker.get::<Vec<u8>>();
            v.iter().sum()
        });
        assert_eq!(sum, 6);

        pool.clear();
    }

    #[test]
    fn worker_pool_reinit_reuses_resources() {
        let pool = WorkerPool::new(1).unwrap();

        pool.broadcast(1, |worker| {
            worker.init::<Vec<u8>>(|existing| {
                assert!(existing.is_none());
                vec![1, 2, 3]
            });
        });

        pool.broadcast(1, |worker| {
            worker.init::<Vec<u8>>(|existing| {
                let v = existing.expect("should have existing state");
                assert_eq!(v, &mut vec![1, 2, 3]);
                v.push(4);
                std::mem::take(v)
            });
        });

        let len = pool.install(|worker| worker.get::<Vec<u8>>().len());
        assert_eq!(len, 4);

        pool.clear();
    }

    #[test]
    fn worker_pool_clear_and_reinit() {
        let pool = WorkerPool::new(1).unwrap();

        pool.broadcast(1, |worker| {
            worker.init::<u64>(|_| 42);
        });
        let val = pool.install(|worker| *worker.get::<u64>());
        assert_eq!(val, 42);

        pool.clear();

        pool.broadcast(1, |worker| {
            worker.init::<String>(|_| "hello".to_string());
        });
        let val = pool.install(|worker| worker.get::<String>().clone());
        assert_eq!(val, "hello");

        pool.clear();
    }

    #[test]
    fn worker_pool_par_iter_with_worker() {
        use rayon::prelude::*;

        let pool = WorkerPool::new(2).unwrap();

        pool.broadcast(2, |worker| {
            worker.init::<u64>(|_| 10);
        });

        let results: Vec<u64> = pool.install(|_| {
            (0u64..4)
                .into_par_iter()
                .map(|i| WorkerPool::with_worker(|w| i + *w.get::<u64>()))
                .collect()
        });
        assert_eq!(results, vec![10, 11, 12, 13]);

        pool.clear();
    }
}

@@ -7,7 +7,7 @@
//! - [`BlockingTaskGuard`] for rate-limiting expensive operations (with `rayon` feature)

#[cfg(feature = "rayon")]
-use crate::pool::{BlockingTaskGuard, BlockingTaskPool};
+use crate::pool::{BlockingTaskGuard, BlockingTaskPool, WorkerPool};
use crate::{
    metrics::{IncCounterOnDrop, TaskExecutorMetrics},
    shutdown::{GracefulShutdown, GracefulShutdownGuard, Shutdown},

@@ -263,13 +263,13 @@ struct RuntimeInner {
    blocking_guard: BlockingTaskGuard,
    /// Proof storage worker pool (trie storage proof computation).
    #[cfg(feature = "rayon")]
-   proof_storage_worker_pool: rayon::ThreadPool,
+   proof_storage_worker_pool: WorkerPool,
    /// Proof account worker pool (trie account proof computation).
    #[cfg(feature = "rayon")]
-   proof_account_worker_pool: rayon::ThreadPool,
+   proof_account_worker_pool: WorkerPool,
    /// Prewarming pool (execution prewarming workers).
    #[cfg(feature = "rayon")]
-   prewarming_pool: rayon::ThreadPool,
+   prewarming_pool: WorkerPool,
    /// Handle to the spawned [`TaskManager`] background task.
    /// The task monitors critical tasks for panics and fires the shutdown signal.
    /// Can be taken via [`Runtime::take_task_manager_handle`] to poll for panic errors.

@@ -349,19 +349,19 @@ impl Runtime {

    /// Get the proof storage worker pool.
    #[cfg(feature = "rayon")]
-   pub fn proof_storage_worker_pool(&self) -> &rayon::ThreadPool {
+   pub fn proof_storage_worker_pool(&self) -> &WorkerPool {
        &self.0.proof_storage_worker_pool
    }

    /// Get the proof account worker pool.
    #[cfg(feature = "rayon")]
-   pub fn proof_account_worker_pool(&self) -> &rayon::ThreadPool {
+   pub fn proof_account_worker_pool(&self) -> &WorkerPool {
        &self.0.proof_account_worker_pool
    }

    /// Get the prewarming pool.
    #[cfg(feature = "rayon")]
-   pub fn prewarming_pool(&self) -> &rayon::ThreadPool {
+   pub fn prewarming_pool(&self) -> &WorkerPool {
        &self.0.prewarming_pool
    }
}
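
With the accessors now returning `&WorkerPool` rather than `&rayon::ThreadPool`, call sites that only need the pool's parallelism switch to `install_fn` (as in the prewarming hunk above), while call sites that want per-thread state go through `broadcast`/`install`. A minimal sketch of a caller, assuming the `rayon` feature is enabled; `sum_on_prewarming_pool` and `items` are hypothetical names, not part of this diff:

    use rayon::prelude::*;

    fn sum_on_prewarming_pool(runtime: &Runtime, items: &[u64]) -> u64 {
        // No per-thread Worker state is needed here, so `install_fn` keeps the old
        // `rayon::ThreadPool::install` call shape.
        runtime.prewarming_pool().install_fn(|| items.par_iter().copied().sum())
    }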

@@ -809,23 +809,26 @@ impl RuntimeBuilder {

        let proof_storage_worker_threads =
            config.rayon.proof_storage_worker_threads.unwrap_or(default_threads);
-       let proof_storage_worker_pool = rayon::ThreadPoolBuilder::new()
-           .num_threads(proof_storage_worker_threads)
-           .thread_name(|i| format!("proof-strg-{i:02}"))
-           .build()?;
+       let proof_storage_worker_pool = WorkerPool::from_builder(
+           rayon::ThreadPoolBuilder::new()
+               .num_threads(proof_storage_worker_threads)
+               .thread_name(|i| format!("proof-strg-{i:02}")),
+       )?;

        let proof_account_worker_threads =
            config.rayon.proof_account_worker_threads.unwrap_or(default_threads);
-       let proof_account_worker_pool = rayon::ThreadPoolBuilder::new()
-           .num_threads(proof_account_worker_threads)
-           .thread_name(|i| format!("proof-acct-{i:02}"))
-           .build()?;
+       let proof_account_worker_pool = WorkerPool::from_builder(
+           rayon::ThreadPoolBuilder::new()
+               .num_threads(proof_account_worker_threads)
+               .thread_name(|i| format!("proof-acct-{i:02}")),
+       )?;

        let prewarming_threads = config.rayon.prewarming_threads.unwrap_or(default_threads);
-       let prewarming_pool = rayon::ThreadPoolBuilder::new()
-           .num_threads(prewarming_threads)
-           .thread_name(|i| format!("prewarm-{i:02}"))
-           .build()?;
+       let prewarming_pool = WorkerPool::from_builder(
+           rayon::ThreadPoolBuilder::new()
+               .num_threads(prewarming_threads)
+               .thread_name(|i| format!("prewarm-{i:02}")),
+       )?;

        debug!(
            default_threads,

@@ -155,6 +155,7 @@ impl ProofWorkerHandle {
        Factory: DatabaseProviderROFactory<Provider: TrieCursorFactory + HashedCursorFactory>
            + Clone
            + Send
            + Sync
            + 'static,
    {
        let (storage_work_tx, storage_work_rx) = unbounded::<StorageWorkerJob>();

@@ -180,30 +181,30 @@ impl ProofWorkerHandle {
            "Spawning proof worker pools"
        );

-       let storage_pool = runtime.proof_storage_worker_pool();
-       let task_ctx_for_storage = task_ctx.clone();
-       let cached_storage_roots_for_storage = cached_storage_roots.clone();
+       // broadcast blocks until all workers exit (channel close), so run on
+       // tokio's blocking pool.
+       let storage_rt = runtime.clone();
+       let storage_task_ctx = task_ctx.clone();
+       let storage_avail = storage_available_workers.clone();
+       let storage_roots = cached_storage_roots.clone();
+       runtime.spawn_blocking(move || {
+           let worker_id = AtomicUsize::new(0);
+           storage_rt.proof_storage_worker_pool().broadcast(storage_worker_count, |_| {
+               let worker_id = worker_id.fetch_add(1, Ordering::Relaxed);
+               let span = debug_span!(target: "trie::proof_task", "storage worker", ?worker_id);
+               let _guard = span.enter();

-       for worker_id in 0..storage_worker_count {
-           let span = debug_span!(target: "trie::proof_task", "storage worker", ?worker_id);
-           let task_ctx_clone = task_ctx_for_storage.clone();
-           let work_rx_clone = storage_work_rx.clone();
-           let storage_available_workers_clone = storage_available_workers.clone();
-           let cached_storage_roots = cached_storage_roots_for_storage.clone();

-           storage_pool.spawn(move || {
                #[cfg(feature = "metrics")]
                let metrics = ProofTaskTrieMetrics::default();
                #[cfg(feature = "metrics")]
                let cursor_metrics = ProofTaskCursorMetrics::new();

-               let _guard = span.enter();
                let worker = StorageProofWorker::new(
-                   task_ctx_clone,
-                   work_rx_clone,
+                   storage_task_ctx.clone(),
+                   storage_work_rx.clone(),
                    worker_id,
-                   storage_available_workers_clone,
-                   cached_storage_roots,
+                   storage_avail.clone(),
+                   storage_roots.clone(),
                    #[cfg(feature = "metrics")]
                    metrics,
                    #[cfg(feature = "metrics")]

@@ -219,32 +220,30 @@ impl ProofWorkerHandle {
                );
            }
            });
-       }
+       });

-       let account_pool = runtime.proof_account_worker_pool();
+       let account_rt = runtime.clone();
+       let account_tx = storage_work_tx.clone();
+       let account_avail = account_available_workers.clone();
+       runtime.spawn_blocking(move || {
+           let worker_id = AtomicUsize::new(0);
+           account_rt.proof_account_worker_pool().broadcast(account_worker_count, |_| {
+               let worker_id = worker_id.fetch_add(1, Ordering::Relaxed);
+               let span = debug_span!(target: "trie::proof_task", "account worker", ?worker_id);
+               let _guard = span.enter();

-       for worker_id in 0..account_worker_count {
-           let span = debug_span!(target: "trie::proof_task", "account worker", ?worker_id);
-           let task_ctx_clone = task_ctx.clone();
-           let work_rx_clone = account_work_rx.clone();
-           let storage_work_tx_clone = storage_work_tx.clone();
-           let account_available_workers_clone = account_available_workers.clone();
-           let cached_storage_roots = cached_storage_roots.clone();

-           account_pool.spawn(move || {
                #[cfg(feature = "metrics")]
                let metrics = ProofTaskTrieMetrics::default();
                #[cfg(feature = "metrics")]
                let cursor_metrics = ProofTaskCursorMetrics::new();

-               let _guard = span.enter();
                let worker = AccountProofWorker::new(
-                   task_ctx_clone,
-                   work_rx_clone,
+                   task_ctx.clone(),
+                   account_work_rx.clone(),
                    worker_id,
-                   storage_work_tx_clone,
-                   account_available_workers_clone,
-                   cached_storage_roots,
+                   account_tx.clone(),
+                   account_avail.clone(),
+                   cached_storage_roots.clone(),
                    #[cfg(feature = "metrics")]
                    metrics,
                    #[cfg(feature = "metrics")]

@@ -260,7 +259,7 @@ impl ProofWorkerHandle {
                );
            }
            });
-       }
+       });

        Self {
            storage_work_tx,
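
Both worker sections above follow the same shape: `WorkerPool::broadcast` only returns once every participating thread has finished its closure, and each proof-worker closure loops until its work channel closes, so the whole broadcast is pushed onto tokio's blocking pool via `spawn_blocking`. A condensed sketch of that pattern under stated assumptions — `Job`, `run_worker`, `spawn_worker_pool`, and the `crossbeam_channel` receiver are stand-ins rather than APIs introduced by this diff, and the `rayon` feature is assumed:

    use std::sync::atomic::{AtomicUsize, Ordering};

    struct Job;

    // Stand-in for the StorageProofWorker/AccountProofWorker run loop: drain the
    // channel until every sender is dropped, then return so `broadcast` can finish.
    fn run_worker(worker_id: usize, work_rx: crossbeam_channel::Receiver<Job>) {
        for _job in work_rx {
            let _ = worker_id; // process the job here
        }
    }

    fn spawn_worker_pool(
        runtime: &Runtime,
        worker_count: usize,
        work_rx: crossbeam_channel::Receiver<Job>,
    ) {
        let rt = runtime.clone();
        runtime.spawn_blocking(move || {
            // Hand out stable ids as pool threads pick up the broadcast closure.
            let next_id = AtomicUsize::new(0);
            // Blocks until all `worker_count` closures return, i.e. until `work_rx`
            // closes, which is why this runs via `spawn_blocking` rather than as an
            // async task.
            rt.proof_storage_worker_pool().broadcast(worker_count, |_| {
                let worker_id = next_id.fetch_add(1, Ordering::Relaxed);
                run_worker(worker_id, work_rx.clone());
            });
        });
    }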