Files
wgpu/examples/hello-compute/main.rs
2019-12-31 16:24:49 +01:00

110 lines
3.5 KiB
Rust

use std::{convert::TryInto as _, str::FromStr};
use zerocopy::AsBytes as _;
async fn run() {
env_logger::init();
// For now this just panics if you didn't pass numbers. Could add proper error handling.
if std::env::args().len() == 1 {
panic!("You must pass a list of positive integers!")
}
let numbers: Vec<u32> = std::env::args()
.skip(1)
.map(|s| u32::from_str(&s).expect("You must pass a list of positive integers!"))
.collect();
let slice_size = numbers.len() * std::mem::size_of::<u32>();
let size = slice_size as wgpu::BufferAddress;
let adapter = wgpu::Adapter::request(
&wgpu::RequestAdapterOptions {
power_preference: wgpu::PowerPreference::Default,
},
wgpu::BackendBit::PRIMARY,
)
.unwrap();
let (device, queue) = adapter.request_device(&wgpu::DeviceDescriptor {
extensions: wgpu::Extensions {
anisotropic_filtering: false,
},
limits: wgpu::Limits::default(),
});
let cs = include_bytes!("shader.comp.spv");
let cs_module =
device.create_shader_module(&wgpu::read_spirv(std::io::Cursor::new(&cs[..])).unwrap());
let staging_buffer = device.create_buffer_with_data(
numbers.as_slice().as_bytes(),
wgpu::BufferUsage::MAP_READ | wgpu::BufferUsage::COPY_DST | wgpu::BufferUsage::COPY_SRC,
);
let storage_buffer = device.create_buffer(&wgpu::BufferDescriptor {
size,
usage: wgpu::BufferUsage::STORAGE
| wgpu::BufferUsage::COPY_DST
| wgpu::BufferUsage::COPY_SRC,
});
let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
bindings: &[wgpu::BindGroupLayoutBinding {
binding: 0,
visibility: wgpu::ShaderStage::COMPUTE,
ty: wgpu::BindingType::StorageBuffer {
dynamic: false,
readonly: false,
},
}],
});
let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
layout: &bind_group_layout,
bindings: &[wgpu::Binding {
binding: 0,
resource: wgpu::BindingResource::Buffer {
buffer: &storage_buffer,
range: 0 .. size,
},
}],
});
let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
bind_group_layouts: &[&bind_group_layout],
});
let compute_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
layout: &pipeline_layout,
compute_stage: wgpu::ProgrammableStageDescriptor {
module: &cs_module,
entry_point: "main",
},
});
let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor { todo: 0 });
encoder.copy_buffer_to_buffer(&staging_buffer, 0, &storage_buffer, 0, size);
{
let mut cpass = encoder.begin_compute_pass();
cpass.set_pipeline(&compute_pipeline);
cpass.set_bind_group(0, &bind_group, &[]);
cpass.dispatch(numbers.len() as u32, 1, 1);
}
encoder.copy_buffer_to_buffer(&storage_buffer, 0, &staging_buffer, 0, size);
queue.submit(&[encoder.finish()]);
if let Ok(mapping) = staging_buffer.map_read(0u64, size).await {
let times : Box<[u32]> = mapping
.as_slice()
.chunks_exact(4)
.map(|b| u32::from_ne_bytes(b.try_into().unwrap()))
.collect();
println!("Times: {:?}", times);
}
}
/// Program entry point: drives the asynchronous `run` routine to completion.
fn main() {
    let compute_job = run();
    futures::executor::block_on(compute_job);
}