From 5f0df67dcc6ff3760328ad98f669c5c7c4f2012f Mon Sep 17 00:00:00 2001 From: Tristam MacDonald Date: Sun, 10 Feb 2019 15:57:04 -0800 Subject: [PATCH] Initial compute pipeline support --- examples/Cargo.toml | 4 ++ examples/data/collatz.comp | 31 ++++++++ examples/data/collatz.comp.spv | Bin 0 -> 1576 bytes examples/hello_compute_rust/main.rs | 107 ++++++++++++++++++++++++++++ wgpu-native/src/device.rs | 59 +++++++++++++-- wgpu-native/src/pipeline.rs | 6 +- wgpu-rs/src/lib.rs | 63 +++++++++++----- 7 files changed, 243 insertions(+), 27 deletions(-) create mode 100644 examples/data/collatz.comp create mode 100644 examples/data/collatz.comp.spv create mode 100644 examples/hello_compute_rust/main.rs diff --git a/examples/Cargo.toml b/examples/Cargo.toml index a0a858040d..b4b5fe6763 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -11,6 +11,10 @@ publish = false name = "hello_triangle" path = "hello_triangle_rust/main.rs" +[[bin]] +name = "hello_compute" +path = "hello_compute_rust/main.rs" + [features] default = [] remote = ["wgpu-native/remote"] diff --git a/examples/data/collatz.comp b/examples/data/collatz.comp new file mode 100644 index 0000000000..d1ab4cfd2e --- /dev/null +++ b/examples/data/collatz.comp @@ -0,0 +1,31 @@ +#version 450 +layout(local_size_x = 1) in; + +layout(set = 0, binding = 0) buffer PrimeIndices { + uint[] indices; +}; // this is used as both input and output for convenience + +// The Collatz Conjecture states that for any integer n: +// If n is even, n = n/2 +// If n is odd, n = 3n+1 +// And repeat this process for each new n, you will always eventually reach 1. +// Though the conjecture has not been proven, no counterexample has ever been found. +// This function returns how many times this recurrence needs to be applied to reach 1. +uint collatz_iterations(uint n) { + uint i = 0; + while(n != 1) { + if (mod(n, 2) == 0) { + n = n / 2; + } + else { + n = (3 * n) + 1; + } + i++; + } + return i; +} + +void main() { + uint index = gl_GlobalInvocationID.x; + indices[index] = collatz_iterations(indices[index]); +} \ No newline at end of file diff --git a/examples/data/collatz.comp.spv b/examples/data/collatz.comp.spv new file mode 100644 index 0000000000000000000000000000000000000000..972544e0f9087b83fd9fb907d1c896aecf817a1d GIT binary patch literal 1576 zcmYk6*-leY6o$8yQItU@C$I$v6oaCSA_!_yNw_fa0Z4_$G$}2CYK&K281H-%Z+!wE z$|=$KeWzz7>||yAYx?)v>#VlWG1HUMnbebd)3ek*!>I#LO8u!j)8*Y)yYnZlgZag! z`(_NJLY8RGAUR#=PUNssYZ!Qd4B^)|jF=Zvptb}1JB}_gPS-zHp!zy1Cmu{D!a#hjGPd(S}#x3-vFEQDX+f?_IIQIWA3{J*6-VcQ{QP$xr1{S?EIsC4y@1LTYMXR Y*83goNv$2<%DD#LU9f!rh>P8g|6_G!ZvX%Q literal 0 HcmV?d00001 diff --git a/examples/hello_compute_rust/main.rs b/examples/hello_compute_rust/main.rs new file mode 100644 index 0000000000..ed5ef05f4c --- /dev/null +++ b/examples/hello_compute_rust/main.rs @@ -0,0 +1,107 @@ +extern crate env_logger; +extern crate wgpu; +extern crate wgpu_native; + +use std::str::FromStr; + +// TODO: deduplicate this with the copy in gfx-examples/framework +pub fn cast_slice(data: &[T]) -> &[u8] { + use std::mem::size_of; + use std::slice::from_raw_parts; + + unsafe { from_raw_parts(data.as_ptr() as *const u8, data.len() * size_of::()) } +} + +fn main() { + env_logger::init(); + + // For now this just panics if you didn't pass numbers. Could add proper error handling. + if std::env::args().len() == 1 { + panic!("You must pass a list of positive integers!") + } + let numbers: Vec = std::env::args() + .skip(1) + .map(|s| u32::from_str(&s).expect("You must pass a list of positive integers!")) + .collect(); + + let size = (numbers.len() * std::mem::size_of::()) as u32; + + let instance = wgpu::Instance::new(); + let adapter = instance.get_adapter(&wgpu::AdapterDescriptor { + power_preference: wgpu::PowerPreference::LowPower, + }); + let mut device = adapter.create_device(&wgpu::DeviceDescriptor { + extensions: wgpu::Extensions { + anisotropic_filtering: false, + }, + }); + + let cs_bytes = include_bytes!("./../data/collatz.comp.spv"); + let cs_module = device.create_shader_module(cs_bytes); + + let staging_buffer = device.create_buffer(&wgpu::BufferDescriptor { + size, + usage: wgpu::BufferUsageFlags::MAP_READ + | wgpu::BufferUsageFlags::TRANSFER_DST + | wgpu::BufferUsageFlags::TRANSFER_SRC, + }); + staging_buffer.set_sub_data(0, cast_slice(&numbers)); + + let storage_buffer = device.create_buffer(&wgpu::BufferDescriptor { + size: (numbers.len() * std::mem::size_of::()) as u32, + usage: wgpu::BufferUsageFlags::STORAGE + | wgpu::BufferUsageFlags::TRANSFER_DST + | wgpu::BufferUsageFlags::TRANSFER_SRC, + }); + + let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor { + bindings: &[wgpu::BindGroupLayoutBinding { + binding: 0, + visibility: wgpu::ShaderStageFlags::COMPUTE, + ty: wgpu::BindingType::StorageBuffer, + }], + }); + + let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor { + layout: &bind_group_layout, + bindings: &[wgpu::Binding { + binding: 0, + resource: wgpu::BindingResource::Buffer { + buffer: &storage_buffer, + range: 0..(numbers.len() as u32), + }, + }], + }); + + let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor { + bind_group_layouts: &[&bind_group_layout], + }); + + let compute_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor { + layout: &pipeline_layout, + compute_stage: wgpu::PipelineStageDescriptor { + module: &cs_module, + stage: wgpu::ShaderStage::Compute, + entry_point: "main", + }, + }); + + let mut cmd_buf = device.create_command_buffer(&wgpu::CommandBufferDescriptor { todo: 0 }); + { + cmd_buf.copy_buffer_tobuffer(&staging_buffer, 0, &storage_buffer, 0, size); + } + { + let mut cpass = cmd_buf.begin_compute_pass(); + cpass.set_pipeline(&compute_pipeline); + cpass.set_bind_group(0, &bind_group); + cpass.dispatch(numbers.len() as u32, 1, 1); + cpass.end_pass(); + } + { + cmd_buf.copy_buffer_tobuffer(&storage_buffer, 0, &staging_buffer, 0, size); + } + + // TODO: read the results back out of the staging buffer + + device.get_queue().submit(&[cmd_buf]); +} diff --git a/wgpu-native/src/device.rs b/wgpu-native/src/device.rs index fff23b1c29..54ff059e59 100644 --- a/wgpu-native/src/device.rs +++ b/wgpu-native/src/device.rs @@ -2,12 +2,10 @@ use crate::{back, binding_model, command, conv, pipeline, resource, swap_chain}; use crate::registry::{HUB, Items}; use crate::track::{BufferTracker, TextureTracker, TrackPermit}; use crate::{ - LifeGuard, RefCount, Stored, SubmissionIndex, WeaklyStored, - BindGroupLayoutId, BindGroupId, - BlendStateId, BufferId, CommandBufferId, DepthStencilStateId, - AdapterId, DeviceId, PipelineLayoutId, QueueId, RenderPipelineId, ShaderModuleId, - SamplerId, TextureId, TextureViewId, - SurfaceId, SwapChainId, + AdapterId, BindGroupId, BindGroupLayoutId, BlendStateId, BufferId, CommandBufferId, + ComputePipelineId, DepthStencilStateId, DeviceId, LifeGuard, PipelineLayoutId, QueueId, + RefCount, RenderPipelineId, SamplerId, ShaderModuleId, Stored, SubmissionIndex, SurfaceId, + SwapChainId, TextureId, TextureViewId, WeaklyStored, }; use hal::command::RawCommandBuffer; @@ -1143,6 +1141,55 @@ pub extern "C" fn wgpu_device_create_render_pipeline( }) } +#[no_mangle] +pub extern "C" fn wgpu_device_create_compute_pipeline( + device_id: DeviceId, + desc: &pipeline::ComputePipelineDescriptor, +) -> ComputePipelineId { + let device_guard = HUB.devices.read(); + let device = device_guard.get(device_id); + let pipeline_layout_guard = HUB.pipeline_layouts.read(); + let layout = &pipeline_layout_guard.get(desc.layout).raw; + let pipeline_stage = &desc.compute_stage; + let shader_module_guard = HUB.shader_modules.read(); + + assert!(pipeline_stage.stage == pipeline::ShaderStage::Compute); // TODO + + let shader = hal::pso::EntryPoint:: { + entry: unsafe { ffi::CStr::from_ptr(pipeline_stage.entry_point) } + .to_str() + .to_owned() + .unwrap(), // TODO + module: &shader_module_guard.get(pipeline_stage.module).raw, + specialization: hal::pso::Specialization { + // TODO + constants: &[], + data: &[], + }, + }; + + // TODO + let flags = hal::pso::PipelineCreationFlags::empty(); + // TODO + let parent = hal::pso::BasePipeline::None; + + let pipeline_desc = hal::pso::ComputePipelineDesc { + shader, + layout, + flags, + parent, + }; + + let pipeline = unsafe { device.raw.create_compute_pipeline(&pipeline_desc, None) }.unwrap(); + + HUB.compute_pipelines + .write() + .register(pipeline::ComputePipeline { + raw: pipeline, + layout_id: WeaklyStored(desc.layout), + }) +} + #[no_mangle] pub extern "C" fn wgpu_device_create_swap_chain( device_id: DeviceId, diff --git a/wgpu-native/src/pipeline.rs b/wgpu-native/src/pipeline.rs index dde4cc93d8..da6e6e55ef 100644 --- a/wgpu-native/src/pipeline.rs +++ b/wgpu-native/src/pipeline.rs @@ -1,12 +1,10 @@ use crate::resource; use crate::{ - ByteArray, WeaklyStored, - BlendStateId, DepthStencilStateId, PipelineLayoutId, ShaderModuleId, + BlendStateId, ByteArray, DepthStencilStateId, PipelineLayoutId, ShaderModuleId, WeaklyStored, }; use bitflags::bitflags; - pub type ShaderAttributeIndex = u32; #[repr(C)] @@ -208,7 +206,7 @@ pub struct PipelineStageDescriptor { #[repr(C)] pub struct ComputePipelineDescriptor { pub layout: PipelineLayoutId, - pub stages: *const PipelineStageDescriptor, + pub compute_stage: PipelineStageDescriptor, } pub(crate) struct ComputePipeline { diff --git a/wgpu-rs/src/lib.rs b/wgpu-rs/src/lib.rs index 9905bf8e63..ae9adf2fd8 100644 --- a/wgpu-rs/src/lib.rs +++ b/wgpu-rs/src/lib.rs @@ -9,15 +9,15 @@ use std::ops::Range; use std::ptr; pub use wgn::{ - AdapterDescriptor, Attachment, BindGroupLayoutBinding, BindingType, BlendStateDescriptor, - BufferDescriptor, BufferUsageFlags, - IndexFormat, VertexFormat, InputStepMode, ShaderAttributeIndex, VertexAttributeDescriptor, - Color, ColorWriteFlags, CommandBufferDescriptor, DepthStencilStateDescriptor, - DeviceDescriptor, Extensions, Extent3d, LoadOp, Origin3d, PowerPreference, PrimitiveTopology, - RenderPassColorAttachmentDescriptor, RenderPassDepthStencilAttachmentDescriptor, + AdapterDescriptor, AddressMode, Attachment, BindGroupLayoutBinding, BindingType, + BlendStateDescriptor, BorderColor, BufferDescriptor, BufferUsageFlags, Color, ColorWriteFlags, + CommandBufferDescriptor, CompareFunction, DepthStencilStateDescriptor, DeviceDescriptor, + Extensions, Extent3d, FilterMode, IndexFormat, InputStepMode, LoadOp, Origin3d, + PowerPreference, PrimitiveTopology, RenderPassColorAttachmentDescriptor, + RenderPassDepthStencilAttachmentDescriptor, SamplerDescriptor, ShaderAttributeIndex, ShaderModuleDescriptor, ShaderStage, ShaderStageFlags, StoreOp, SwapChainDescriptor, TextureDescriptor, TextureDimension, TextureFormat, TextureUsageFlags, TextureViewDescriptor, - SamplerDescriptor, AddressMode, FilterMode, CompareFunction, BorderColor, + VertexAttributeDescriptor, VertexFormat, }; pub struct Instance { @@ -162,6 +162,11 @@ pub struct RenderPipelineDescriptor<'a> { pub vertex_buffers: &'a [VertexBufferDescriptor<'a>], } +pub struct ComputePipelineDescriptor<'a> { + pub layout: &'a PipelineLayout, + pub compute_stage: PipelineStageDescriptor<'a>, +} + pub struct RenderPassDescriptor<'a> { pub color_attachments: &'a [RenderPassColorAttachmentDescriptor<&'a TextureView>], pub depth_stencil_attachment: @@ -210,7 +215,6 @@ impl<'a> TextureCopyView<'a> { } } - impl Instance { pub fn new() -> Self { Instance { @@ -273,14 +277,17 @@ impl Device { .map(|binding| wgn::Binding { binding: binding.binding, resource: match binding.resource { - BindingResource::Buffer { ref buffer, ref range } => { - wgn::BindingResource::Buffer(wgn::BufferBinding { - buffer: buffer.id, - offset: range.start, - size: range.end, - }) + BindingResource::Buffer { + ref buffer, + ref range, + } => wgn::BindingResource::Buffer(wgn::BufferBinding { + buffer: buffer.id, + offset: range.start, + size: range.end, + }), + BindingResource::Sampler(ref sampler) => { + wgn::BindingResource::Sampler(sampler.id) } - BindingResource::Sampler(ref sampler) => wgn::BindingResource::Sampler(sampler.id), BindingResource::TextureView(ref texture_view) => { wgn::BindingResource::TextureView(texture_view.id) } @@ -362,7 +369,8 @@ impl Device { .collect::>(); let temp_blend_states = desc.blend_states.iter().map(|bs| bs.id).collect::>(); - let temp_vertex_buffers = desc.vertex_buffers + let temp_vertex_buffers = desc + .vertex_buffers .iter() .map(|vbuf| wgn::VertexBufferDescriptor { stride: vbuf.stride, @@ -403,6 +411,25 @@ impl Device { } } + pub fn create_compute_pipeline(&self, desc: &ComputePipelineDescriptor) -> ComputePipeline { + let entry_point = CString::new(desc.compute_stage.entry_point).unwrap(); + let compute_stage = wgn::PipelineStageDescriptor { + module: desc.compute_stage.module.id, + stage: desc.compute_stage.stage, + entry_point: entry_point.as_ptr(), + }; + + ComputePipeline { + id: wgn::wgpu_device_create_compute_pipeline( + self.id, + &wgn::ComputePipelineDescriptor { + layout: desc.layout.id, + compute_stage, + }, + ), + } + } + pub fn create_buffer(&self, desc: &BufferDescriptor) -> Buffer { Buffer { id: wgn::wgpu_device_create_buffer(self.id, desc), @@ -651,7 +678,9 @@ impl SwapChain { pub fn get_next_texture(&mut self) -> SwapChainOutput { let output = wgn::wgpu_swap_chain_get_next_texture(self.id); SwapChainOutput { - texture: Texture { id: output.texture_id }, + texture: Texture { + id: output.texture_id, + }, view: TextureView { id: output.view_id }, swap_chain_id: &self.id, }