Timestamp normalization

This commit is contained in:
Connor Fitzgerald
2025-03-26 18:29:07 -04:00
parent 7da699608d
commit 6a986f4bc4
19 changed files with 1179 additions and 5 deletions

Cargo.lock generated
View File

@@ -4863,6 +4863,7 @@ dependencies = [
"js-sys",
"libtest-mimic",
"log",
"nanorand",
"nv-flip",
"parking_lot",
"png",

View File

@@ -49,6 +49,7 @@ itertools.workspace = true
image.workspace = true
libtest-mimic.workspace = true
log.workspace = true
nanorand.workspace = true
parking_lot.workspace = true
png.workspace = true
pollster.workspace = true

View File

@@ -56,6 +56,8 @@ mod texture_binding;
mod texture_blit;
mod texture_bounds;
mod texture_view_creation;
mod timestamp_normalization;
mod timestamp_query;
mod transfer;
mod transition_resources;
mod vertex_formats;

View File

@@ -1,8 +1,14 @@
use wgpu::InstanceFlags;
use wgpu_test::{gpu_test, FailureCase, GpuTestConfiguration, TestParameters};
#[gpu_test]
static OCCLUSION_QUERY: GpuTestConfiguration = GpuTestConfiguration::new()
.parameters(TestParameters::default().expect_fail(FailureCase::webgl2()))
.parameters(
TestParameters::default()
.expect_fail(FailureCase::webgl2())
// Ensure timestamp normalization does not interfere with occlusion query results
.instance_flags(InstanceFlags::AUTOMATIC_TIMESTAMP_NORMALIZATION),
)
.run_async(|ctx| async move {
// Create depth texture
let depth_texture = ctx.device.create_texture(&wgpu::TextureDescriptor {

View File

@@ -0,0 +1 @@
mod utils;

View File

@@ -0,0 +1,22 @@
// Must have "wgpu-core/src/timestamp_normalization/common.wgsl"
// preprocessed before this file's contents.
struct ShiftRight96 {
value: Uint96,
shift: u32,
}
@group(0) @binding(0)
var<storage> input: array<ShiftRight96>;
@group(0) @binding(1)
var<storage, read_write> output: array<Uint96>;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) id: vec3u) {
let index = id.x;
let input = input[index];
output[index] = shift_right_96(input.value, input.shift);
}

View File

@@ -0,0 +1,23 @@
// Must have "wgpu-core/src/timestamp_normalization/common.wgsl"
// preprocessed before this file's contents.
struct U64MulU32Input {
left: Uint64,
right: u32,
_padding: u32,
}
@group(0) @binding(0)
var<storage> input: array<U64MulU32Input>;
@group(0) @binding(1)
var<storage, read_write> output: array<Uint96>;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) id: vec3u) {
let index = id.x;
let input = input[index];
output[index] = u64_mul_u32(input.left, input.right);
}

View File

@@ -0,0 +1,285 @@
//! Tests for the timestamp normalization algorithm's utility functions.
//!
//! Because the algorithm relies on several kinds of hand-rolled math
//! routines, we test those building blocks to ensure the overall
//! operation (which is itself very simple) works correctly.
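//!
//! Each test mirrors the GPU computation on the CPU in `u128` arithmetic
//! and asserts bit-exact agreement over a mix of hand-picked edge cases
//! and randomly generated inputs.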
use nanorand::Rng;
use wgpu::{util::DeviceExt, Limits};
use wgpu_test::{gpu_test, GpuTestConfiguration, TestParameters, TestingContext};
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct Uint96(u32, u32, u32);
impl Uint96 {
fn from_u128(value: u128) -> Self {
let a = (value & 0xFFFF_FFFF) as u32;
let b = ((value >> 32) & 0xFFFF_FFFF) as u32;
let c = ((value >> 64) & 0xFFFF_FFFF) as u32;
Self(a, b, c)
}
fn as_u128(&self) -> u128 {
((self.2 as u128) << 64) | ((self.1 as u128) << 32) | (self.0 as u128)
}
}
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct U64MulU32Input {
left: u64,
right: u32,
_pad: u32,
}
impl U64MulU32Input {
fn new(left: u64, right: u32) -> Self {
Self {
left,
right,
_pad: 0,
}
}
}
fn assert_u64_mul_u32(left: u64, right: u32, computed: Uint96) {
let real = left as u128 * right as u128;
let computed = computed.as_u128();
assert_eq!(
computed, real,
"{left} * {right} should be {real} but is {}",
computed
);
}
#[gpu_test]
static U64_MUL_U32: GpuTestConfiguration = GpuTestConfiguration::new()
.parameters(
TestParameters::default()
.test_features_limits()
.limits(Limits {
max_storage_buffer_binding_size: 256 * 1024 * 1024,
..Limits::downlevel_defaults()
}),
)
.run_sync(test_u64_mul_u32);
fn test_u64_mul_u32(ctx: TestingContext) {
const TOTAL_RANDOM_INPUTS: usize = 1_000_000;
const MANUAL_INPUTS: usize = 2;
const TOTAL_INPUTS: usize = TOTAL_RANDOM_INPUTS + MANUAL_INPUTS;
let mut inputs = Vec::with_capacity(TOTAL_INPUTS);
inputs.push(U64MulU32Input::new(2, 2));
inputs.push(U64MulU32Input::new(u64::MAX, u32::MAX));
// Smoke test the algorithm by generating 1M random inputs, and checking the results.
let mut generator = nanorand::WyRand::new_seed(0xDEAD_BEEF);
for _ in 0..TOTAL_RANDOM_INPUTS {
let left = generator.generate::<u64>();
let right = generator.generate::<u32>();
inputs.push(U64MulU32Input::new(left, right));
}
assert_eq!(TOTAL_INPUTS, inputs.len());
let output_bytes = process_shader(
ctx,
bytemuck::cast_slice(&inputs),
include_str!("u64_mul_u32.wgsl"),
);
let output_values = bytemuck::pod_collect_to_vec(&output_bytes);
for (&input, &output) in inputs.iter().zip(output_values.iter()) {
assert_u64_mul_u32(input.left, input.right, output);
}
}
#[repr(C)]
#[derive(Debug, Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct ShiftRightU96Input {
value: Uint96,
shift: u32,
}
impl ShiftRightU96Input {
fn new(value: u128, shift: u32) -> Self {
assert!(shift <= 32);
assert!(value >> 96 == 0);
Self {
value: Uint96::from_u128(value),
shift,
}
}
}
fn assert_shift_right_u96(value: Uint96, shift: u32, computed: Uint96) {
let value = value.as_u128();
let real = value >> shift;
let computed = computed.as_u128();
assert_eq!(
computed, real,
"{value:X} >> {shift} should be {real:X} but is {computed:X}",
);
}
#[gpu_test]
static SHIFT_RIGHT_U96: GpuTestConfiguration = GpuTestConfiguration::new()
.parameters(
TestParameters::default()
.test_features_limits()
.limits(Limits {
max_storage_buffer_binding_size: 256 * 1024 * 1024,
..Limits::downlevel_defaults()
}),
)
.run_sync(test_shift_right_u96);
fn test_shift_right_u96(ctx: TestingContext) {
const TOTAL_RANDOM_INPUTS: usize = 1_000_000;
const TOTAL_SHIFT_INPUTS: usize = 33;
const MANUAL_INPUTS: usize = 1;
const TOTAL_INPUTS: usize = TOTAL_RANDOM_INPUTS + TOTAL_SHIFT_INPUTS + MANUAL_INPUTS;
let mut inputs = Vec::with_capacity(TOTAL_INPUTS);
inputs.push(ShiftRightU96Input::new(1, 1));
for shift in 0..TOTAL_SHIFT_INPUTS {
// 96-bit number with a visually recognizable pattern.
const INTERESTING_NUMBER: u128 = 0x1234_5678_9ABC_DEF0_1234_5678;
inputs.push(ShiftRightU96Input::new(INTERESTING_NUMBER, shift as u32));
}
// Smoke test the algorithm by generating 1M random inputs, and checking the results.
let mut generator = nanorand::WyRand::new_seed(0xDEAD_BEEF);
for _ in 0..TOTAL_RANDOM_INPUTS {
// nanorand doesn't have generate_range for u128, so just chop the top bits off.
let value = generator.generate::<u128>() >> 32;
let shift = generator.generate_range(0..=32);
inputs.push(ShiftRightU96Input::new(value, shift));
}
assert_eq!(TOTAL_INPUTS, inputs.len());
let output_bytes = process_shader(
ctx,
bytemuck::cast_slice(&inputs),
include_str!("shift_right_u96.wgsl"),
);
let output_values = bytemuck::pod_collect_to_vec(&output_bytes);
for (&input, &output) in inputs.iter().zip(output_values.iter()) {
assert_shift_right_u96(input.value, input.shift, output);
}
}
fn process_shader(ctx: TestingContext, inputs: &[u8], entry_point_src: &str) -> Vec<u8> {
let common_src = include_str!("../../../../wgpu-core/src/timestamp_normalization/common.wgsl");
let full_source = format!("{common_src}\n{entry_point_src}");
let shader_module = ctx
.device
.create_shader_module(wgpu::ShaderModuleDescriptor {
label: Some("u64_mul_u32"),
source: wgpu::ShaderSource::Wgsl(full_source.into()),
});
let pipeline = ctx
.device
.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: Some("u64_mul_u32"),
layout: None,
module: &shader_module,
entry_point: None,
compilation_options: wgpu::PipelineCompilationOptions::default(),
cache: None,
});
let input_buffer = ctx
.device
.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("Input Buffer"),
contents: inputs,
usage: wgpu::BufferUsages::STORAGE,
});
let output_size = (size_of::<Uint96>() * inputs.len()) as u64;
let output_buffer = ctx.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("Output Buffer"),
size: output_size,
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
mapped_at_creation: false,
});
let pulldown_buffer = ctx.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("Pulldown Buffer"),
size: output_size,
usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
mapped_at_creation: false,
});
let bgl = pipeline.get_bind_group_layout(0);
let bg = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
label: Some("Bind Group"),
layout: &bgl,
entries: &[
wgpu::BindGroupEntry {
binding: 0,
resource: input_buffer.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 1,
resource: output_buffer.as_entire_binding(),
},
],
});
let mut encoder = ctx
.device
.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("Compute Encoder"),
});
let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: Some("Compute Pass"),
timestamp_writes: None,
});
cpass.set_pipeline(&pipeline);
cpass.set_bind_group(0, &bg, &[]);
cpass.dispatch_workgroups(inputs.len().div_ceil(256) as u32, 1, 1);
drop(cpass);
encoder.copy_buffer_to_buffer(&output_buffer, 0, &pulldown_buffer, 0, output_size);
ctx.queue.submit([encoder.finish()]);
pulldown_buffer.map_async(wgpu::MapMode::Read, .., |_| {});
ctx.device.poll(wgpu::PollType::Wait).unwrap();
let value = pulldown_buffer.get_mapped_range(..).to_vec();
value
}

View File

@@ -0,0 +1,132 @@
use wgpu::{
util::DeviceExt, ComputePassTimestampWrites, Features, InstanceFlags,
QUERY_RESOLVE_BUFFER_ALIGNMENT,
};
use wgpu_test::{gpu_test, FailureCase, GpuTestConfiguration, TestParameters, TestingContext};
const SHADER: &str = r#"
@compute @workgroup_size(1)
fn main() {
return;
}
"#;
const ITERATIONS: u32 = 10;
const QUERIES_PER_ITERATION: u32 = 2;
const TOTAL_QUERIES: u32 = QUERIES_PER_ITERATION * ITERATIONS;
#[gpu_test]
static TIMESTAMP_QUERY: GpuTestConfiguration = GpuTestConfiguration::new()
.parameters(
TestParameters::default()
.expect_fail(FailureCase::webgl2())
.test_features_limits()
.features(Features::TIMESTAMP_QUERY)
// Ensure timestamp normalization functions correctly
.instance_flags(InstanceFlags::AUTOMATIC_TIMESTAMP_NORMALIZATION),
)
.run_sync(timestamp_query);
fn timestamp_query(ctx: TestingContext) {
// Setup pipeline using a simple shader with hardcoded vertices
let shader = ctx
.device
.create_shader_module(wgpu::ShaderModuleDescriptor {
label: Some("timestamp query shader"),
source: wgpu::ShaderSource::Wgsl(SHADER.into()),
});
let pipeline = ctx
.device
.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: Some("Pipeline"),
layout: None,
module: &shader,
entry_point: None,
compilation_options: wgpu::PipelineCompilationOptions::default(),
cache: None,
});
// Create timestamp query set
let query_set = ctx.device.create_query_set(&wgpu::QuerySetDescriptor {
label: Some("Query set"),
ty: wgpu::QueryType::Timestamp,
count: TOTAL_QUERIES,
});
let mut encoder = ctx
.device
.create_command_encoder(&wgpu::CommandEncoderDescriptor::default());
for i in 0..ITERATIONS {
let base_index = i * QUERIES_PER_ITERATION;
let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: Some("compute pass"),
timestamp_writes: Some(ComputePassTimestampWrites {
query_set: &query_set,
beginning_of_pass_write_index: Some(base_index),
end_of_pass_write_index: Some(base_index + 1),
}),
});
compute_pass.set_pipeline(&pipeline);
compute_pass.dispatch_workgroups(1, 1, 1);
}
let buffer_size = QUERY_RESOLVE_BUFFER_ALIGNMENT * TOTAL_QUERIES as u64;
let init_constant = 0x0123_4567_89AB_CDEFu64;
let init_data = vec![init_constant; buffer_size as usize / 8];
// Resolve query set to buffer
let query_buffer = ctx
.device
.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("Query buffer"),
contents: bytemuck::cast_slice(&init_data),
usage: wgpu::BufferUsages::QUERY_RESOLVE | wgpu::BufferUsages::COPY_SRC,
});
for i in 0..ITERATIONS {
let start_query = i * QUERIES_PER_ITERATION;
let end_query = start_query + QUERIES_PER_ITERATION;
let buffer_offset = i as u64 * QUERY_RESOLVE_BUFFER_ALIGNMENT;
encoder.resolve_query_set(
&query_set,
start_query..end_query,
&query_buffer,
buffer_offset,
);
}
let mapping_buffer = ctx.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("Mapping buffer"),
size: query_buffer.size(),
usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
encoder.copy_buffer_to_buffer(&query_buffer, 0, &mapping_buffer, 0, query_buffer.size());
ctx.queue.submit(Some(encoder.finish()));
mapping_buffer
.slice(..)
.map_async(wgpu::MapMode::Read, |_| ());
ctx.device.poll(wgpu::PollType::wait()).unwrap();
let query_buffer_view = mapping_buffer.slice(..).get_mapped_range();
let query_data: &[u64] = bytemuck::cast_slice(&query_buffer_view);
for i in 0..ITERATIONS {
let byte_offset = i as u64 * QUERY_RESOLVE_BUFFER_ALIGNMENT;
let query_offset = byte_offset / 8;
// WebGPU does not define the value of the timestamp queries. They unfortunately
// can be `0` in some situations. However, we should expect that some value
// has been written, and the odds of it being exactly `init_constant` are vanishingly low.
assert_ne!(query_data[query_offset as usize], init_constant);
assert_ne!(query_data[query_offset as usize + 1], init_constant);
}
}

View File

@@ -465,6 +465,26 @@ impl Global {
);
}
if matches!(query_set.desc.ty, wgt::QueryType::Timestamp) {
// Timestamp normalization is only needed for timestamps.
cmd_buf
.device
.timestamp_normalizer
.get()
.unwrap()
.normalize(
&snatch_guard,
raw_encoder,
&mut cmd_buf_data.trackers.buffers,
dst_buffer
.timestamp_normalization_bind_group
.get(&snatch_guard)
.unwrap(),
&dst_buffer,
start_query..end_query,
);
}
cmd_buf_data.trackers.query_sets.insert_single(query_set);
cmd_buf_data_guard.mark_successful();

View File

@@ -1488,6 +1488,11 @@ impl Global {
pub fn queue_get_timestamp_period(&self, queue_id: QueueId) -> f32 {
let queue = self.hub.queues.get(queue_id);
if queue.device.timestamp_normalizer.get().unwrap().enabled() {
return 1.0;
}
queue.get_timestamp_period()
}

View File

@@ -35,7 +35,7 @@ use crate::{
BufferInitTracker, BufferInitTrackerAction, MemoryInitKind, TextureInitRange,
TextureInitTrackerAction,
},
instance::Adapter,
instance::{Adapter, RequestDeviceError},
lock::{rank, Mutex, RwLock},
pipeline,
pool::ResourcePool,
@@ -45,6 +45,7 @@ use crate::{
},
resource_log,
snatch::{SnatchGuard, SnatchLock, Snatchable},
timestamp_normalization::TIMESTAMP_NORMALIZATION_BUFFER_USES,
track::{BindGroupStates, DeviceTracker, TrackerIndexAllocators, UsageScope, UsageScopePool},
validation::{self, validate_color_attachment_bytes_per_sample},
weak_vec::WeakVec,
@@ -130,6 +131,9 @@ pub struct Device {
pub(crate) usage_scopes: UsageScopePool,
pub(crate) last_acceleration_structure_build_command_index: AtomicU64,
pub(crate) indirect_validation: Option<crate::indirect_validation::IndirectValidation>,
// Optional so that we can late-initialize this after the queue is created.
pub(crate) timestamp_normalizer:
OnceCellOrLock<crate::timestamp_normalization::TimestampNormalizer>,
// needs to be dropped last
#[cfg(feature = "trace")]
pub(crate) trace: Mutex<Option<trace::Trace>>,
@@ -162,6 +166,9 @@ impl Drop for Device {
if let Some(indirect_validation) = self.indirect_validation.take() {
indirect_validation.dispose(self.raw.as_ref());
}
if let Some(timestamp_normalizer) = self.timestamp_normalizer.take() {
timestamp_normalizer.dispose(self.raw.as_ref());
}
unsafe {
self.raw.destroy_buffer(zero_buffer);
self.raw.destroy_fence(fence);
@@ -307,10 +314,26 @@ impl Device {
usage_scopes: Mutex::new(rank::DEVICE_USAGE_SCOPES, Default::default()),
// By starting at one, we can put the result in a NonZeroU64.
last_acceleration_structure_build_command_index: AtomicU64::new(1),
timestamp_normalizer: OnceCellOrLock::new(),
indirect_validation,
})
}
pub fn late_init_resources_with_queue(&self) -> Result<(), RequestDeviceError> {
let queue = self.get_queue().unwrap();
let timestamp_normalizer = crate::timestamp_normalization::TimestampNormalizer::new(
self,
queue.get_timestamp_period(),
)?;
self.timestamp_normalizer
.set(timestamp_normalizer)
.unwrap_or_else(|_| panic!("Called late_init_resources_with_queue twice"));
Ok(())
}
/// Returns the backend this device is using.
pub fn backend(&self) -> wgt::Backend {
self.adapter.backend()
@@ -606,6 +629,10 @@ impl Device {
usage |= wgt::BufferUses::STORAGE_READ_ONLY | wgt::BufferUses::STORAGE_READ_WRITE;
}
if desc.usage.contains(wgt::BufferUsages::QUERY_RESOLVE) {
usage |= TIMESTAMP_NORMALIZATION_BUFFER_USES;
}
if desc.mapped_at_creation {
if desc.size % wgt::COPY_BUFFER_ALIGNMENT != 0 {
return Err(resource::CreateBufferError::UnalignedSize);
@@ -645,6 +672,18 @@ impl Device {
let buffer =
unsafe { self.raw().create_buffer(&hal_desc) }.map_err(|e| self.handle_hal_error(e))?;
let timestamp_normalization_bind_group = Snatchable::new(
self.timestamp_normalizer
.get()
.unwrap()
.create_normalization_bind_group(
self,
&*buffer,
desc.label.as_deref(),
desc.usage,
)?,
);
let indirect_validation_bind_groups =
self.create_indirect_validation_bind_groups(buffer.as_ref(), desc.size, desc.usage)?;
@@ -661,6 +700,7 @@ impl Device {
label: desc.label.to_string(),
tracking_data: TrackingData::new(self.tracker_indices.buffers.clone()),
bind_groups: Mutex::new(rank::BUFFER_BIND_GROUPS, WeakVec::new()),
timestamp_normalization_bind_group,
indirect_validation_bind_groups,
};
@@ -743,6 +783,21 @@ impl Device {
hal_buffer: Box<dyn hal::DynBuffer>,
desc: &resource::BufferDescriptor,
) -> (Fallible<Buffer>, Option<resource::CreateBufferError>) {
let timestamp_normalization_bind_group = match self
.timestamp_normalizer
.get()
.unwrap()
.create_normalization_bind_group(self, &*hal_buffer, desc.label.as_deref(), desc.usage)
{
Ok(bg) => Snatchable::new(bg),
Err(e) => {
return (
Fallible::Invalid(Arc::new(desc.label.to_string())),
Some(e.into()),
)
}
};
let indirect_validation_bind_groups = match self.create_indirect_validation_bind_groups(
hal_buffer.as_ref(),
desc.size,
@@ -767,6 +822,7 @@ impl Device {
label: desc.label.to_string(),
tracking_data: TrackingData::new(self.tracker_indices.buffers.clone()),
bind_groups: Mutex::new(rank::BUFFER_BIND_GROUPS, WeakVec::new()),
timestamp_normalization_bind_group,
indirect_validation_bind_groups,
};

View File

@@ -19,7 +19,9 @@ use crate::{
lock::{rank, Mutex},
present::Presentation,
resource::ResourceType,
resource_log, DOWNLEVEL_WARNING_MESSAGE,
resource_log,
timestamp_normalization::TimestampNormalizerInitError,
DOWNLEVEL_WARNING_MESSAGE,
};
use wgt::{Backend, Backends, PowerPreference};
@@ -80,7 +82,7 @@ pub struct Instance {
/// `instance_per_backend` instead.
supported_backends: Backends,
flags: wgt::InstanceFlags,
pub flags: wgt::InstanceFlags,
}
impl Instance {
@@ -764,6 +766,7 @@ impl Adapter {
let queue = Arc::new(queue);
device.set_queue(&queue);
device.late_init_resources_with_queue()?;
Ok((device, queue))
}
@@ -835,7 +838,6 @@ pub enum GetSurfaceSupportError {
}
#[derive(Clone, Debug, Error)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
/// Error when requesting a device from the adapter
#[non_exhaustive]
pub enum RequestDeviceError {
@@ -843,6 +845,8 @@ pub enum RequestDeviceError {
Device(#[from] DeviceError),
#[error(transparent)]
LimitsExceeded(#[from] FailedLimit),
#[error("Failed to initialize Timestamp Normalizer")]
TimestampNormalizerInitFailed(#[from] TimestampNormalizerInitError),
#[error("Unsupported features were requested: {0:?}")]
UnsupportedFeature(wgt::Features),
}

View File

@@ -89,6 +89,7 @@ pub mod registry;
pub mod resource;
mod snatch;
pub mod storage;
mod timestamp_normalization;
mod track;
mod weak_vec;
// This is public for users who pre-compile shaders while still wanting to

View File

@@ -30,6 +30,7 @@ use crate::{
lock::{rank, Mutex, RwLock},
resource_log,
snatch::{SnatchGuard, Snatchable},
timestamp_normalization::TimestampNormalizationBindGroup,
track::{SharedTrackerIndexAllocator, TrackerIndex},
weak_vec::WeakVec,
Label, LabelHelpers, SubmissionIndex,
@@ -366,14 +367,20 @@ pub struct Buffer {
pub(crate) tracking_data: TrackingData,
pub(crate) map_state: Mutex<BufferMapState>,
pub(crate) bind_groups: Mutex<WeakVec<BindGroup>>,
pub(crate) timestamp_normalization_bind_group: Snatchable<TimestampNormalizationBindGroup>,
pub(crate) indirect_validation_bind_groups: Snatchable<crate::indirect_validation::BindGroups>,
}
impl Drop for Buffer {
fn drop(&mut self) {
if let Some(raw) = self.timestamp_normalization_bind_group.take() {
raw.dispose(self.device.raw());
}
if let Some(raw) = self.indirect_validation_bind_groups.take() {
raw.dispose(self.device.raw());
}
if let Some(raw) = self.raw.take() {
resource_log!("Destroy raw {}", self.error_ident());
unsafe {
@@ -708,6 +715,10 @@ impl Buffer {
}
};
let timestamp_normalization_bind_group = self
.timestamp_normalization_bind_group
.snatch(&mut snatch_guard);
let indirect_validation_bind_groups = self
.indirect_validation_bind_groups
.snatch(&mut snatch_guard);
@@ -724,6 +735,7 @@ impl Buffer {
device: Arc::clone(&self.device),
label: self.label().to_owned(),
bind_groups,
timestamp_normalization_bind_group,
indirect_validation_bind_groups,
})
};
@@ -779,6 +791,7 @@ pub struct DestroyedBuffer {
device: Arc<Device>,
label: String,
bind_groups: WeakVec<BindGroup>,
timestamp_normalization_bind_group: Option<TimestampNormalizationBindGroup>,
indirect_validation_bind_groups: Option<crate::indirect_validation::BindGroups>,
}
@@ -796,6 +809,10 @@ impl Drop for DestroyedBuffer {
)));
drop(deferred);
if let Some(raw) = self.timestamp_normalization_bind_group.take() {
raw.dispose(self.device.raw());
}
if let Some(raw) = self.indirect_validation_bind_groups.take() {
raw.dispose(self.device.raw());
}

View File

@@ -0,0 +1,142 @@
// Common routines for timestamp normalization.
//
// This is split out into its own file so that the tests in `tests` can include
// it without including the normal endpoints and interface definitions.
/// 64-bit unsigned integer type.
///
/// We cannot rely on native 64-bit integers, so we define our own 64-bit
/// integer type as two 32-bit integers.
struct Uint64 {
/// Least significant word.
low: u32,
/// Most significant word.
high: u32,
}
/// 96-bit unsigned integer type.
struct Uint96 {
/// Least significant word.
low: u32,
/// Middle word.
mid: u32,
/// Most significant word.
high: u32,
}
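// A `Uint96` represents the value `high * 2^64 + mid * 2^32 + low`,
// and a `Uint64` the value `high * 2^32 + low`.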
/// Truncates a 96-bit number to a 64-bit number by discarding the upper 32 bits.
fn truncate_u96_to_u64(x: Uint96) -> Uint64 {
return Uint64(
x.low,
x.mid,
);
}
/// Returns the lower 16 bits of a 32-bit integer.
fn low(a: u32) -> u32 {
return a & 0xFFFF;
}
/// Returns the upper 16 bits of a 32-bit integer.
fn high(a: u32) -> u32 {
return a >> 16;
}
/// Combines two 16-bit words into a single 32-bit word.
/// `w1` is the upper 16 bits and `w0` is the lower 16 bits.
///
/// The high 16 bits of each argument are discarded.
fn u32_from_u16s(w1: u32, w0: u32) -> u32 {
return (low(w1) << 16) | low(w0);
}
// Multiplies a 64-bit number by a 32-bit number and outputs a 96-bit result.
//
// The number of digits (bits) needed to represent the result of a multiplication
// is the sum of the number of input digits (bits). Since we are multiplying a
// 64-bit number by a 32-bit number, we need 96 bits to represent the result.
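//
// For example, (2^32 + 5) * 3 = 3 * 2^32 + 15, so
// u64_mul_u32(Uint64(5u, 1u), 3u) returns Uint96(15u, 3u, 0u).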
fn u64_mul_u32(a: Uint64, b: u32) -> Uint96 {
// WGSL gives us no 64-bit operations, and no `mul(u32, u32) -> u64`
// operation either, so we operate entirely on `mul(u16, u16) -> u32`.
// This implements the standard "long multiplication" algorithm using 16-bit words.
// Each element in this diagram is a 16-bit word, and each 32-bit
// product pXY = aX * bY occupies two 16-bit columns (high and low half).
//
//                           a3   a2   a1   a0
//  *                                  b1   b0
//  ------------------------------------------
//  i0 =                         p00  p00
//  i1 =                    p10  p10
//  i2 =               p20  p20
//  i3 =          p30  p30
//  i4 =                    p01  p01
//  i5 =               p11  p11
//  i6 =          p21  p21
//  i7 =     p31  p31
//  ------------------------------------------
//      r6   r5   r4   r3   r2   r1   r0
// Decompose the 64-bit number into four 16-bit words.
let a0 = low(a.low);
let a1 = high(a.low);
let a2 = low(a.high);
let a3 = high(a.high);
// Decompose the 32-bit number into two 16-bit words.
let b0 = low(b);
let b1 = high(b);
// Each line represents one row in the diagram above.
let i0 = a0 * b0;
let i1 = a1 * b0;
let i2 = a2 * b0;
let i3 = a3 * b0;
let i4 = a0 * b1;
let i5 = a1 * b1;
let i6 = a2 * b1;
let i7 = a3 * b1;
// Each line represents one column in the diagram above.
//
// The high 16 bits of each column are the carry to the next column.
let r0 = low(i0);
let r1 = high(i0) + low(i1) + low(i4) + high(r0);
let r2 = high(i1) + low(i2) + high(i4) + low(i5) + high(r1);
let r3 = high(i2) + low(i3) + high(i5) + low(i6) + high(r2);
let r4 = high(i3) + high(i6) + low(i7) + high(r3);
let r5 = high(i7) + high(r4);
// The carry out of r5 (the would-be r6 column) is always zero.
let out0 = u32_from_u16s(r1, r0);
let out1 = u32_from_u16s(r3, r2);
let out2 = u32_from_u16s(r5, r4);
return Uint96(out0, out1, out2);
}
// Shifts a 96-bit number right by a given number of bits.
//
// The shift is in the range [0, 32].
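//
// For example, shifting Uint96(0u, 1u, 0u) (= 2^32) right by one
// yields Uint96(0x80000000u, 0u, 0u) (= 2^31).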
fn shift_right_96(x: Uint96, shift: u32) -> Uint96 {
// Shift wraps around at 32, which breaks the algorithm when
// either shift is 32 or inv_shift is 32.
if (shift == 0) {
return x;
}
if (shift == 32) {
return Uint96(x.mid, x.high, 0);
}
let inv_shift = 32 - shift;
let carry2 = x.high << inv_shift;
let carry1 = x.mid << inv_shift;
var out: Uint96;
out.high = x.high >> shift;
out.mid = (x.mid >> shift) | carry2;
out.low = (x.low >> shift) | carry1;
return out;
}

View File

@@ -0,0 +1,403 @@
//! Utility for normalizing GPU timestamp queries to a consistent
//! 1ns period (i.e. a 1GHz tick rate). This uses a compute shader to do
//! the normalization, so the timestamps exist in their correct format
//! on the GPU, as is required by the WebGPU specification.
//!
//! ## Algorithm
//!
//! The fundamental operation is multiplying a u64 timestamp by an f32
//! value. We have neither f64s nor u64s in shaders, so we need to do
//! something more complicated.
//!
//! We first decompose the f32 into a u32 fraction whose denominator
//! is a power of two. We perform that decomposition in f64 for ease of
//! computation, as f64 can store every u32 losslessly.
//!
//! Because the denominator is a power of two, the shader can evaluate
//! the division with a simple right shift. Additionally, we always choose
//! the largest denominator we can, so that the fraction is as precise as possible.
//!
//! To evaluate this multiply-then-shift, we have two helper operations (both in common.wgsl).
//!
//! 1. `u64_mul_u32` multiplies a u64 by a u32 and returns a u96.
//! 2. `shift_right_u96` shifts a u96 right by a given amount, returning a u96.
//!
//! See their implementations for more details.
//!
//! We then multiply the timestamp by the numerator, and shift it right by the
//! denominator's power-of-two exponent. This gives us the normalized timestamp.
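//!
//! For example, a timestamp period of 0.5 decomposes exactly into the
//! fraction `1_073_741_824 / 2^31`: each timestamp is multiplied by
//! `1_073_741_824` into a 96-bit intermediate and then shifted right by
//! 31 bits.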
use alloc::{boxed::Box, string::String, string::ToString, sync::Arc};
use hashbrown::HashMap;
use wgt::PushConstantRange;
use crate::{
device::{Device, DeviceError},
pipeline::{CreateComputePipelineError, CreateShaderModuleError},
resource::Buffer,
snatch::SnatchGuard,
track::BufferTracker,
};
pub const TIMESTAMP_NORMALIZATION_BUFFER_USES: wgt::BufferUses =
wgt::BufferUses::STORAGE_READ_WRITE;
struct InternalState {
temporary_bind_group_layout: Box<dyn hal::DynBindGroupLayout>,
pipeline_layout: Box<dyn hal::DynPipelineLayout>,
pipeline: Box<dyn hal::DynComputePipeline>,
}
#[derive(Debug, Clone, thiserror::Error)]
pub enum TimestampNormalizerInitError {
#[error("Failed to initialize bind group layout")]
BindGroupLayout(#[source] DeviceError),
#[cfg(feature = "wgsl")]
#[error("Failed to parse shader")]
ParseWgsl(#[source] naga::error::ShaderError<naga::front::wgsl::ParseError>),
#[error("Failed to validate shader module")]
ValidateWgsl(#[source] naga::error::ShaderError<naga::WithSpan<naga::valid::ValidationError>>),
#[error("Failed to create shader module")]
CreateShaderModule(#[from] CreateShaderModuleError),
#[error("Failed to create pipeline layout")]
PipelineLayout(#[source] DeviceError),
#[error("Failed to create compute pipeline")]
ComputePipeline(#[from] CreateComputePipelineError),
}
/// Normalizes GPU timestamps to a consistent 1ns period (a 1GHz tick rate).
/// See module documentation for more information.
pub struct TimestampNormalizer {
state: Option<InternalState>,
}
impl TimestampNormalizer {
/// Creates a new timestamp normalizer.
///
/// If the device cannot support automatic timestamp normalization,
/// this will return a normalizer that does nothing.
///
/// # Errors
///
/// If any resources are invalid, this will return an error.
pub fn new(
device: &Device,
timestamp_period: f32,
) -> Result<Self, TimestampNormalizerInitError> {
unsafe {
if !device
.instance_flags
.contains(wgt::InstanceFlags::AUTOMATIC_TIMESTAMP_NORMALIZATION)
{
return Ok(Self { state: None });
}
if !device
.downlevel
.flags
.contains(wgt::DownlevelFlags::COMPUTE_SHADERS)
{
log::error!("Automatic timestamp normalization was requested, but compute shaders are not supported.");
return Ok(Self { state: None });
}
if timestamp_period == 1.0 {
// If the period is already 1, the timestamps need no adjustment.
return Ok(Self { state: None });
}
let temporary_bind_group_layout = device
.raw()
.create_bind_group_layout(&hal::BindGroupLayoutDescriptor {
label: Some("Timestamp Normalization Bind Group Layout"),
flags: hal::BindGroupLayoutFlags::empty(),
entries: &[wgt::BindGroupLayoutEntry {
binding: 0,
visibility: wgt::ShaderStages::COMPUTE,
ty: wgt::BindingType::Buffer {
ty: wgt::BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
}],
})
.map_err(|e| {
TimestampNormalizerInitError::BindGroupLayout(device.handle_hal_error(e))
})?;
let common_src = include_str!("common.wgsl");
let src = include_str!("timestamp_normalization.wgsl");
let preprocessed_src = alloc::format!("{common_src}\n{src}");
#[cfg(feature = "wgsl")]
let module = naga::front::wgsl::parse_str(&preprocessed_src).map_err(|inner| {
TimestampNormalizerInitError::ParseWgsl(naga::error::ShaderError {
source: preprocessed_src.clone(),
label: None,
inner: Box::new(inner),
})
})?;
#[cfg(not(feature = "wgsl"))]
#[allow(clippy::diverging_sub_expression)]
let module =
panic!("Timestamp normalization requires the wgsl feature flag to be enabled!");
let info = crate::device::create_validator(
wgt::Features::PUSH_CONSTANTS,
wgt::DownlevelFlags::empty(),
naga::valid::ValidationFlags::all(),
)
.validate(&module)
.map_err(|inner| {
TimestampNormalizerInitError::ValidateWgsl(naga::error::ShaderError {
source: preprocessed_src.clone(),
label: None,
inner: Box::new(inner),
})
})?;
let hal_shader = hal::ShaderInput::Naga(hal::NagaShader {
module: alloc::borrow::Cow::Owned(module),
info,
debug_source: None,
});
let hal_desc = hal::ShaderModuleDescriptor {
label: None,
runtime_checks: wgt::ShaderRuntimeChecks::unchecked(),
};
let module = device
.raw()
.create_shader_module(&hal_desc, hal_shader)
.map_err(|error| match error {
hal::ShaderError::Device(error) => {
CreateShaderModuleError::Device(device.handle_hal_error(error))
}
hal::ShaderError::Compilation(ref msg) => {
log::error!("Shader error: {}", msg);
CreateShaderModuleError::Generation
}
})?;
let pipeline_layout = device
.raw()
.create_pipeline_layout(&hal::PipelineLayoutDescriptor {
label: None,
bind_group_layouts: &[temporary_bind_group_layout.as_ref()],
push_constant_ranges: &[PushConstantRange {
stages: wgt::ShaderStages::COMPUTE,
range: 0..8,
}],
flags: hal::PipelineLayoutFlags::empty(),
})
.map_err(|e| {
TimestampNormalizerInitError::PipelineLayout(device.handle_hal_error(e))
})?;
let (multiplier, shift) = compute_timestamp_period(timestamp_period);
let mut constants = HashMap::with_capacity(2);
constants.insert(String::from("TIMESTAMP_PERIOD_MULTIPLY"), multiplier as f64);
constants.insert(String::from("TIMESTAMP_PERIOD_SHIFT"), shift as f64);
let pipeline_desc = hal::ComputePipelineDescriptor {
label: None,
layout: pipeline_layout.as_ref(),
stage: hal::ProgrammableStage {
module: module.as_ref(),
entry_point: "main",
constants: &constants,
zero_initialize_workgroup_memory: false,
},
cache: None,
};
let pipeline = device
.raw()
.create_compute_pipeline(&pipeline_desc)
.map_err(|err| match err {
hal::PipelineError::Device(error) => {
CreateComputePipelineError::Device(device.handle_hal_error(error))
}
hal::PipelineError::Linkage(_stages, msg) => {
CreateComputePipelineError::Internal(msg)
}
hal::PipelineError::EntryPoint(_stage) => CreateComputePipelineError::Internal(
crate::device::ENTRYPOINT_FAILURE_ERROR.to_string(),
),
hal::PipelineError::PipelineConstants(_, error) => {
CreateComputePipelineError::PipelineConstants(error)
}
})?;
Ok(Self {
state: Some(InternalState {
temporary_bind_group_layout,
pipeline_layout,
pipeline,
}),
})
}
}
pub fn create_normalization_bind_group(
&self,
device: &Device,
buffer: &dyn hal::DynBuffer,
buffer_label: Option<&str>,
buffer_usages: wgt::BufferUsages,
) -> Result<TimestampNormalizationBindGroup, DeviceError> {
unsafe {
let Some(ref state) = &self.state else {
return Ok(TimestampNormalizationBindGroup { raw: None });
};
if !buffer_usages.contains(wgt::BufferUsages::QUERY_RESOLVE) {
return Ok(TimestampNormalizationBindGroup { raw: None });
}
let bg_label_alloc;
let label = match buffer_label {
Some(label) => {
bg_label_alloc =
alloc::format!("Timestamp normalization bind group ({})", label);
&*bg_label_alloc
}
None => "Timestamp normalization bind group",
};
let bg = device
.raw()
.create_bind_group(&hal::BindGroupDescriptor {
label: Some(label),
layout: &*state.temporary_bind_group_layout,
buffers: &[hal::BufferBinding {
buffer,
offset: 0,
size: None,
}],
samplers: &[],
textures: &[],
acceleration_structures: &[],
entries: &[hal::BindGroupEntry {
binding: 0,
resource_index: 0,
count: 1,
}],
})
.map_err(|e| device.handle_hal_error(e))?;
Ok(TimestampNormalizationBindGroup { raw: Some(bg) })
}
}
pub fn normalize(
&self,
snatch_guard: &SnatchGuard<'_>,
encoder: &mut dyn hal::DynCommandEncoder,
tracker: &mut BufferTracker,
bind_group: &TimestampNormalizationBindGroup,
buffer: &Arc<Buffer>,
range: core::ops::Range<u32>,
) {
let Some(ref state) = &self.state else {
return;
};
let Some(bind_group) = bind_group.raw.as_deref() else {
return;
};
let pending_barrier = tracker.set_single(buffer, wgt::BufferUses::STORAGE_READ_WRITE);
let barrier = pending_barrier.map(|pending| pending.into_hal(buffer, snatch_guard));
let total_timestamps = range.len() as u32;
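// One invocation normalizes one timestamp, and the shader is compiled
// with `@workgroup_size(64)`, so round the dispatch up accordingly.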
let needed_workgroups = total_timestamps.div_ceil(64);
unsafe {
encoder.transition_buffers(barrier.as_slice());
encoder.begin_compute_pass(&hal::ComputePassDescriptor {
label: Some("Timestamp normalization pass"),
timestamp_writes: None,
});
encoder.set_compute_pipeline(&*state.pipeline);
encoder.set_bind_group(&*state.pipeline_layout, 0, Some(bind_group), &[]);
encoder.set_push_constants(
&*state.pipeline_layout,
wgt::ShaderStages::COMPUTE,
0,
&[range.start, range.len() as u32],
);
encoder.dispatch([needed_workgroups, 1, 1]);
encoder.end_compute_pass();
}
}
pub fn dispose(self, device: &dyn hal::DynDevice) {
unsafe {
let Some(state) = self.state else {
return;
};
device.destroy_compute_pipeline(state.pipeline);
device.destroy_pipeline_layout(state.pipeline_layout);
device.destroy_bind_group_layout(state.temporary_bind_group_layout);
}
}
pub fn enabled(&self) -> bool {
self.state.is_some()
}
}
#[derive(Debug)]
pub struct TimestampNormalizationBindGroup {
raw: Option<Box<dyn hal::DynBindGroup>>,
}
impl TimestampNormalizationBindGroup {
pub fn dispose(self, device: &dyn hal::DynDevice) {
unsafe {
if let Some(raw) = self.raw {
device.destroy_bind_group(raw);
}
}
}
}
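/// Decomposes `input` into a fraction `multiplier / 2^shift` with
/// `shift <= 32`, choosing a large power-of-two denominator so the
/// fraction approximates `input` as precisely as possible.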
fn compute_timestamp_period(input: f32) -> (u32, u32) {
let pow2 = input.log2().ceil() as i32;
let clamped_pow2 = pow2.clamp(-32, 32).unsigned_abs();
let shift = 32 - clamped_pow2;
let denominator = (1u64 << shift) as f64;
// float -> int conversions are defined to saturate.
let multiplier = (input as f64 * denominator).round() as u32;
(multiplier, shift)
}
#[cfg(test)]
mod tests {
use core::f64;
fn assert_timestamp_case(input: f32) {
let (multiplier, shift) = super::compute_timestamp_period(input);
let output = multiplier as f64 / (1u64 << shift) as f64;
assert!((input as f64 - output).abs() < 0.0000001);
}
#[test]
fn compute_timestamp_period() {
assert_timestamp_case(0.01);
assert_timestamp_case(0.5);
assert_timestamp_case(1.0);
assert_timestamp_case(2.0);
assert_timestamp_case(2.7);
assert_timestamp_case(1000.7);
}
}
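
For reference, the end-to-end operation a single shader invocation performs can be mirrored on the CPU with plain u128 arithmetic. The sketch below is illustrative only (normalize_timestamp is a hypothetical helper, not part of this commit):

// Widen to 128 bits, multiply by the fraction's numerator (the result
// fits in 96 bits), shift right by the power-of-two denominator, then
// truncate back down to 64 bits.
fn normalize_timestamp(ts: u64, multiplier: u32, shift: u32) -> u64 {
    let wide = ts as u128 * multiplier as u128; // u64_mul_u32
    (wide >> shift) as u64 // shift_right_96 + truncate_u96_to_u64
}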

View File

@@ -0,0 +1,41 @@
// Must have "common.wgsl" preprocessed before this file's contents.
//
// To compile this locally, you can run:
// ```
// cat common.wgsl timestamp_normalization.wgsl | cargo run -p naga-cli -- --stdin-file-path timestamp_normalization.wgsl
// ```
// For an explanation of the timestamp normalization process, see
// the `mod.rs` file in this folder.
// This is the timestamp period turned into a fraction
// with an integer numerator and denominator. The denominator
// is a power of two, so the division can be done with a shift.
override TIMESTAMP_PERIOD_MULTIPLY: u32 = 1;
override TIMESTAMP_PERIOD_SHIFT: u32 = 0;
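// For example, a timestamp period of 3.0 would be specialized as
// TIMESTAMP_PERIOD_MULTIPLY = 3221225472 and TIMESTAMP_PERIOD_SHIFT = 30,
// since 3.0 == 3221225472 / 2^30 exactly.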
@group(0) @binding(0)
var<storage, read_write> timestamps: array<Uint64>;
struct PushConstants {
timestamp_offset: u32,
timestamp_count: u32,
}
var<push_constant> pc: PushConstants;
@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) id: vec3u) {
if id.x >= pc.timestamp_count {
return;
}
let index = id.x + pc.timestamp_offset;
let input_value = timestamps[index];
let tmp1 = u64_mul_u32(input_value, TIMESTAMP_PERIOD_MULTIPLY);
let tmp2 = shift_right_96(tmp1, TIMESTAMP_PERIOD_SHIFT);
timestamps[index] = truncate_u96_to_u64(tmp2);
}

View File

@@ -103,6 +103,18 @@ bitflags::bitflags! {
///
/// When `Self::from_env()` is used, this takes its value from the `WGPU_VALIDATION_INDIRECT_CALL` environment variable.
const VALIDATION_INDIRECT_CALL = 1 << 5;
/// Enable automatic timestamp normalization. This means that in [`CommandEncoder::resolve_query_set`][rqs],
/// the timestamps will automatically be normalized to be in nanoseconds instead of the raw timestamp values.
///
/// This is disabled by default because it introduces a compute shader into the resolution of query sets.
///
/// This can be useful for users who need to read timestamps on the GPU, as the normalization
/// can be a hassle to do manually. When this is enabled, the timestamp period returned by the queue
/// will always be `1.0`.
///
/// [rqs]: ../wgpu/struct.CommandEncoder.html#method.resolve_query_set
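///
/// A minimal usage sketch (hypothetical host-side code, not part of this file):
///
/// ```ignore
/// let instance = wgpu::Instance::new(&wgpu::InstanceDescriptor {
///     flags: wgpu::InstanceFlags::AUTOMATIC_TIMESTAMP_NORMALIZATION,
///     ..Default::default()
/// });
/// // Timestamps resolved by devices from this instance are in nanoseconds,
/// // and `Queue::get_timestamp_period` returns 1.0.
/// ```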
const AUTOMATIC_TIMESTAMP_NORMALIZATION = 1 << 6;
}
}