Move multithreaded_compute.rs into hello-compute tests (#223)

Setup hello-compute tests to run during `cargo test`
This commit is contained in:
Lucas Kent
2020-03-30 10:55:39 +11:00
committed by GitHub
parent d91b78bdfb
commit d08f837624
3 changed files with 31 additions and 124 deletions

View File

@@ -61,4 +61,9 @@ futures = "0.3"
#[patch."https://github.com/gfx-rs/wgpu"]
#wgc = { version = "0.1.0", package = "wgpu-core", path = "../wgpu/wgpu-core" }
#wgt = { version = "0.1.0", package = "wgpu-types", path = "../wgpu/wgpu-types" }
#wgn = { version = "0.4.0", package = "wgpu-native", path = "../wgpu/wgpu-native" }
#wgn = { version = "0.4.0", package = "wgpu-native", path = "../wgpu/wgpu-native" }
[[example]]
name="hello-compute"
path="examples/hello-compute/main.rs"
test = true

View File

@@ -141,7 +141,31 @@ mod tests {
futures::executor::block_on(assert_execute_gpu(input, vec!(5, 15, 6, 19)));
}
async fn assert_execute_gpu(input: Vec<u32>, expected: Vec<u32>){
#[test]
fn test_multithreaded_compute() {
use std::sync::mpsc;
use std::thread;
use std::time::Duration;
let thread_count = 8;
let (tx, rx) = mpsc::channel();
for _ in 0 .. thread_count {
let tx = tx.clone();
thread::spawn(move || {
let input = vec![100, 100, 100];
futures::executor::block_on(assert_execute_gpu(input, vec!(25, 25, 25)));
tx.send(true).unwrap();
});
}
for _ in 0 .. thread_count {
rx.recv_timeout(Duration::from_secs(10))
.expect("A thread never completed.");
}
}
async fn assert_execute_gpu(input: Vec<u32>, expected: Vec<u32>) {
assert_eq!(execute_gpu(input).await, expected);
}
}

View File

@@ -1,122 +0,0 @@
#[cfg(any(feature = "vulkan", feature = "metal", feature = "dx12"))]
use std::convert::TryInto as _;
#[test]
#[cfg(any(feature = "vulkan", feature = "metal", feature = "dx12"))]
fn multithreaded_compute() {
use std::sync::mpsc;
use std::thread;
use std::time::Duration;
let thread_count = 8;
let (tx, rx) = mpsc::channel();
for _ in 0 .. thread_count {
let tx = tx.clone();
thread::spawn(move || {
let numbers = vec![100, 100, 100];
let slice_size = numbers.len() * std::mem::size_of::<u32>();
let size = slice_size as wgpu::BufferAddress;
let adapter = wgpu::Adapter::request(
&wgpu::RequestAdapterOptions {
power_preference: wgpu::PowerPreference::Default,
},
wgpu::BackendBit::PRIMARY,
)
.unwrap();
let (device, mut queue) = adapter.request_device(&wgpu::DeviceDescriptor {
extensions: wgpu::Extensions {
anisotropic_filtering: false,
},
limits: wgpu::Limits::default(),
});
let cs = include_bytes!("../examples/hello-compute/shader.comp.spv");
let cs_module = device
.create_shader_module(&wgpu::read_spirv(std::io::Cursor::new(&cs[..])).unwrap());
let staging_buffer = device.create_buffer_with_data(
numbers.as_slice(),
wgpu::BufferUsage::MAP_READ
| wgpu::BufferUsage::COPY_DST
| wgpu::BufferUsage::COPY_SRC,
);
let storage_buffer = device.create_buffer(&wgpu::BufferDescriptor {
size,
usage: wgpu::BufferUsage::STORAGE
| wgpu::BufferUsage::COPY_DST
| wgpu::BufferUsage::COPY_SRC,
});
let bind_group_layout =
device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
bindings: &[wgpu::BindGroupLayoutBinding {
binding: 0,
visibility: wgpu::ShaderStage::COMPUTE,
ty: wgpu::BindingType::StorageBuffer {
dynamic: false,
readonly: false,
},
}],
});
let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
layout: &bind_group_layout,
bindings: &[wgpu::Binding {
binding: 0,
resource: wgpu::BindingResource::Buffer {
buffer: &storage_buffer,
range: 0 .. size,
},
}],
});
let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
bind_group_layouts: &[&bind_group_layout],
});
let compute_pipeline =
device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
layout: &pipeline_layout,
compute_stage: wgpu::ProgrammableStageDescriptor {
module: &cs_module,
entry_point: "main",
},
});
let mut encoder =
device.create_command_encoder(&wgpu::CommandEncoderDescriptor { todo: 0 });
encoder.copy_buffer_to_buffer(&staging_buffer, 0, &storage_buffer, 0, size);
{
let mut cpass = encoder.begin_compute_pass();
cpass.set_pipeline(&compute_pipeline);
cpass.set_bind_group(0, &bind_group, &[]);
cpass.dispatch(numbers.len() as u32, 1, 1);
}
encoder.copy_buffer_to_buffer(&storage_buffer, 0, &staging_buffer, 0, size);
queue.submit(&[encoder.finish()]);
// FIXME: Align and use `LayoutVerified`
staging_buffer.map_read_async(0, slice_size, |result| {
let result_data: Box<[u32]> = result
.unwrap()
.data
.chunks_exact(std::mem::size_of::<u32>())
.map(|c| u32::from_ne_bytes(c.try_into().unwrap()))
.collect();
assert_eq!(&*result_data, &[25, 25, 25]);
});
tx.send(true).unwrap();
});
}
for _ in 0 .. thread_count {
rx.recv_timeout(Duration::from_secs(10))
.expect("A thread never completed.");
}
}