diff --git a/src/back/glsl/mod.rs b/src/back/glsl/mod.rs index 44685fb99e..419f4c4385 100644 --- a/src/back/glsl/mod.rs +++ b/src/back/glsl/mod.rs @@ -879,6 +879,8 @@ impl<'a, W: Write> Writer<'a, W> { | TypeInner::Struct { .. } | TypeInner::Image { .. } | TypeInner::Sampler { .. } + | TypeInner::AccelerationStructure + | TypeInner::RayQuery | TypeInner::BindingArray { .. } => { return Err(Error::Custom(format!("Unable to write type {inner:?}"))) } diff --git a/src/back/msl/writer.rs b/src/back/msl/writer.rs index ee23ca294a..68945315c9 100644 --- a/src/back/msl/writer.rs +++ b/src/back/msl/writer.rs @@ -194,6 +194,9 @@ impl<'a> Display for TypeContext<'a> { crate::TypeInner::Sampler { comparison: _ } => { write!(out, "{NAMESPACE}::sampler") } + crate::TypeInner::AccelerationStructure | crate::TypeInner::RayQuery => { + unreachable!("Ray queries are not supported yet"); + } crate::TypeInner::BindingArray { base, size } => { let base_tyname = Self { handle: base, @@ -485,7 +488,11 @@ impl crate::Type { // composite types are better to be aliased, regardless of the name Ti::Struct { .. } | Ti::Array { .. } => true, // handle types may be different, depending on the global var access, so we always inline them - Ti::Image { .. } | Ti::Sampler { .. } | Ti::BindingArray { .. } => false, + Ti::Image { .. } + | Ti::Sampler { .. } + | Ti::AccelerationStructure + | Ti::RayQuery + | Ti::BindingArray { .. } => false, } } } diff --git a/src/back/spv/writer.rs b/src/back/spv/writer.rs index f264c107d3..fbc53feedd 100644 --- a/src/back/spv/writer.rs +++ b/src/back/spv/writer.rs @@ -1017,7 +1017,9 @@ impl Writer { | crate::TypeInner::Pointer { .. } | crate::TypeInner::ValuePointer { .. } | crate::TypeInner::Image { .. } - | crate::TypeInner::Sampler { .. } => unreachable!(), + | crate::TypeInner::Sampler { .. } + | crate::TypeInner::AccelerationStructure + | crate::TypeInner::RayQuery => unreachable!(), }; instruction.to_words(&mut self.logical_layout.declarations); diff --git a/src/front/wgsl/lower/mod.rs b/src/front/wgsl/lower/mod.rs index f3b157caa7..8dfb735af4 100644 --- a/src/front/wgsl/lower/mod.rs +++ b/src/front/wgsl/lower/mod.rs @@ -2245,6 +2245,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { class, }, ast::Type::Sampler { comparison } => crate::TypeInner::Sampler { comparison }, + ast::Type::AccelerationStructure => crate::TypeInner::AccelerationStructure, + ast::Type::RayQuery => crate::TypeInner::RayQuery, ast::Type::BindingArray { base, size } => { let base = self.resolve_ast_type(base, ctx.reborrow())?; diff --git a/src/front/wgsl/mod.rs b/src/front/wgsl/mod.rs index 8ac82fe45e..eb21fae6c9 100644 --- a/src/front/wgsl/mod.rs +++ b/src/front/wgsl/mod.rs @@ -206,6 +206,8 @@ impl crate::TypeInner { format!("texture{class_suffix}{dim_suffix}{array_suffix}{type_in_brackets}") } Ti::Sampler { .. } => "sampler".to_string(), + Ti::AccelerationStructure => "acceleration_structure".to_string(), + Ti::RayQuery => "ray_query".to_string(), Ti::BindingArray { base, size, .. } => { let member_type = &types[base]; let base = member_type.name.as_deref().unwrap_or("unknown"); diff --git a/src/front/wgsl/parse/ast.rs b/src/front/wgsl/parse/ast.rs index 734d9769fe..a5da4a49cc 100644 --- a/src/front/wgsl/parse/ast.rs +++ b/src/front/wgsl/parse/ast.rs @@ -229,6 +229,8 @@ pub enum Type<'a> { Sampler { comparison: bool, }, + AccelerationStructure, + RayQuery, BindingArray { base: Handle>, size: ArraySize<'a>, diff --git a/src/front/wgsl/parse/mod.rs b/src/front/wgsl/parse/mod.rs index 7ff762d673..f082ec1c4e 100644 --- a/src/front/wgsl/parse/mod.rs +++ b/src/front/wgsl/parse/mod.rs @@ -1367,6 +1367,8 @@ impl Parser { class: crate::ImageClass::Storage { format, access }, } } + "acceleration_structure" => ast::Type::AccelerationStructure, + "ray_query" => ast::Type::RayQuery, _ => return Ok(None), })) } diff --git a/src/lib.rs b/src/lib.rs index c1b48b8991..f8491f3a36 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -721,6 +721,11 @@ pub enum TypeInner { /// Can be used to sample values from images. Sampler { comparison: bool }, + /// Opaque object representing an acceleration structure of geometry. + AccelerationStructure, + /// Locally used handle for ray queries. + RayQuery, + /// Array of bindings. /// /// A `BindingArray` represents an array where each element draws its value diff --git a/src/proc/layouter.rs b/src/proc/layouter.rs index db07f261a4..65369d1cc8 100644 --- a/src/proc/layouter.rs +++ b/src/proc/layouter.rs @@ -238,7 +238,11 @@ impl Layouter { alignment, } } - Ti::Image { .. } | Ti::Sampler { .. } | Ti::BindingArray { .. } => TypeLayout { + Ti::Image { .. } + | Ti::Sampler { .. } + | Ti::AccelerationStructure + | Ti::RayQuery + | Ti::BindingArray { .. } => TypeLayout { size, alignment: Alignment::ONE, }, diff --git a/src/proc/mod.rs b/src/proc/mod.rs index 6a8bfa03c7..a775272a19 100644 --- a/src/proc/mod.rs +++ b/src/proc/mod.rs @@ -134,7 +134,11 @@ impl super::TypeInner { count * stride } Self::Struct { span, .. } => span, - Self::Image { .. } | Self::Sampler { .. } | Self::BindingArray { .. } => 0, + Self::Image { .. } + | Self::Sampler { .. } + | Self::AccelerationStructure + | Self::RayQuery + | Self::BindingArray { .. } => 0, } } diff --git a/src/valid/expression.rs b/src/valid/expression.rs index af080fc183..9063eb0616 100644 --- a/src/valid/expression.rs +++ b/src/valid/expression.rs @@ -1379,7 +1379,7 @@ impl super::Validator { _ => return Err(ExpressionError::InvalidCastArgument), }; let width = convert.unwrap_or(base_width); - if !self.check_width(kind, width) { + if self.check_width(kind, width).is_err() { return Err(ExpressionError::InvalidCastArgument); } ShaderStages::all() @@ -1390,7 +1390,7 @@ impl super::Validator { &crate::TypeInner::Scalar { kind: kind @ (crate::ScalarKind::Uint | crate::ScalarKind::Sint), width, - } => self.check_width(kind, width), + } => self.check_width(kind, width).is_ok(), _ => false, }; let good = match &module.types[ty].inner { diff --git a/src/valid/handles.rs b/src/valid/handles.rs index e3f9fe2531..871a73a219 100644 --- a/src/valid/handles.rs +++ b/src/valid/handles.rs @@ -76,7 +76,9 @@ impl super::Validator { | crate::TypeInner::ValuePointer { .. } | crate::TypeInner::Atomic { .. } | crate::TypeInner::Image { .. } - | crate::TypeInner::Sampler { .. } => (), + | crate::TypeInner::Sampler { .. } + | crate::TypeInner::AccelerationStructure + | crate::TypeInner::RayQuery => (), crate::TypeInner::Pointer { base, space: _ } => { this_handle.check_dep(base)?; } diff --git a/src/valid/mod.rs b/src/valid/mod.rs index eb92e8892d..6b3a2e1456 100644 --- a/src/valid/mod.rs +++ b/src/valid/mod.rs @@ -111,6 +111,8 @@ bitflags::bitflags! { const EARLY_DEPTH_TEST = 0x400; /// Support for [`Builtin::SampleIndex`] and [`Sampling::Sample`]. const MULTISAMPLED_SHADING = 0x800; + /// Support for ray queries and acceleration structures. + const RAY_QUERY = 0x1000; } } @@ -238,6 +240,8 @@ impl crate::TypeInner { Self::Array { .. } | Self::Image { .. } | Self::Sampler { .. } + | Self::AccelerationStructure + | Self::RayQuery | Self::BindingArray { .. } => false, } } @@ -302,7 +306,7 @@ impl Validator { let con = &constants[handle]; match con.inner { crate::ConstantInner::Scalar { width, ref value } => { - if !self.check_width(value.scalar_kind(), width) { + if self.check_width(value.scalar_kind(), width).is_err() { return Err(ConstantError::InvalidType); } } diff --git a/src/valid/type.rs b/src/valid/type.rs index 4fcc1a1c58..23f6ef4d1f 100644 --- a/src/valid/type.rs +++ b/src/valid/type.rs @@ -90,6 +90,8 @@ pub enum Disalignment { #[derive(Clone, Debug, thiserror::Error)] pub enum TypeError { + #[error("Capability {0:?} is required")] + MissingCapability(Capabilities), #[error("The {0:?} scalar width {1} is not supported")] InvalidWidth(crate::ScalarKind, crate::Bytes), #[error("The {0:?} scalar width {1} is not supported for an atomic")] @@ -203,13 +205,35 @@ impl TypeInfo { } impl super::Validator { - pub(super) const fn check_width(&self, kind: crate::ScalarKind, width: crate::Bytes) -> bool { - match kind { + fn require_type_capability(&self, capability: Capabilities) -> Result<(), TypeError> { + if self.capabilities.contains(capability) { + Ok(()) + } else { + Err(TypeError::MissingCapability(capability)) + } + } + + pub(super) fn check_width( + &self, + kind: crate::ScalarKind, + width: crate::Bytes, + ) -> Result<(), TypeError> { + let good = match kind { crate::ScalarKind::Bool => width == crate::BOOL_WIDTH, crate::ScalarKind::Float => { - width == 4 || (width == 8 && self.capabilities.contains(Capabilities::FLOAT64)) + if width == 8 { + self.require_type_capability(Capabilities::FLOAT64)?; + true + } else { + width == 4 + } } crate::ScalarKind::Sint | crate::ScalarKind::Uint => width == 4, + }; + if good { + Ok(()) + } else { + Err(TypeError::InvalidWidth(kind, width)) } } @@ -228,9 +252,7 @@ impl super::Validator { use crate::TypeInner as Ti; Ok(match types[handle].inner { Ti::Scalar { kind, width } => { - if !self.check_width(kind, width) { - return Err(TypeError::InvalidWidth(kind, width)); - } + self.check_width(kind, width)?; let shareable = if kind.is_numeric() { TypeFlags::IO_SHAREABLE | TypeFlags::HOST_SHAREABLE } else { @@ -247,9 +269,7 @@ impl super::Validator { ) } Ti::Vector { size, kind, width } => { - if !self.check_width(kind, width) { - return Err(TypeError::InvalidWidth(kind, width)); - } + self.check_width(kind, width)?; let shareable = if kind.is_numeric() { TypeFlags::IO_SHAREABLE | TypeFlags::HOST_SHAREABLE } else { @@ -271,9 +291,7 @@ impl super::Validator { rows, width, } => { - if !self.check_width(crate::ScalarKind::Float, width) { - return Err(TypeError::InvalidWidth(crate::ScalarKind::Float, width)); - } + self.check_width(crate::ScalarKind::Float, width)?; TypeInfo::new( TypeFlags::DATA | TypeFlags::SIZED @@ -355,9 +373,7 @@ impl super::Validator { // However, some cases are trivial: All our implicit base types // are DATA and SIZED, so we can never return // `InvalidPointerBase` or `InvalidPointerToUnsized`. - if !self.check_width(kind, width) { - return Err(TypeError::InvalidWidth(kind, width)); - } + self.check_width(kind, width)?; // `Validator::validate_function` actually checks the storage // space of pointer arguments explicitly before checking the @@ -606,6 +622,10 @@ impl super::Validator { Ti::Image { .. } | Ti::Sampler { .. } => { TypeInfo::new(TypeFlags::ARGUMENT, Alignment::ONE) } + Ti::AccelerationStructure | Ti::RayQuery => { + self.require_type_capability(Capabilities::RAY_QUERY)?; + TypeInfo::new(TypeFlags::empty(), Alignment::ONE) + } Ti::BindingArray { .. } => TypeInfo::new(TypeFlags::empty(), Alignment::ONE), }) } diff --git a/tests/in/ray-query.param.ron b/tests/in/ray-query.param.ron new file mode 100644 index 0000000000..9d8666954d --- /dev/null +++ b/tests/in/ray-query.param.ron @@ -0,0 +1,6 @@ +( + god_mode: true, + spv: ( + version: (1, 4), + ), +) diff --git a/tests/in/ray-query.wgsl b/tests/in/ray-query.wgsl new file mode 100644 index 0000000000..26023ff2f1 --- /dev/null +++ b/tests/in/ray-query.wgsl @@ -0,0 +1,17 @@ +var acc_struct: acceleration_structure; + +struct Output { + visible: u32, +} +var output: Output; + +@compute +fn main() { + var rq: ray_query; + + rayQueryInitialize(rq, acceleration_structure, RAY_FLAGS_TERMINATE_ON_FIRST_HIT, 0xFF, vec3(0.0), 0.1, vec3(0.0, 1.0, 0.0), 100.0); + + rayQueryProceed(rq); + + output.visible = rayQueryGetCommittedIntersectionType(rq) == RAY_QUERY_COMMITTED_INTERSECTION_NONE; +} diff --git a/tests/snapshots.rs b/tests/snapshots.rs index 691c074a93..df94130a70 100644 --- a/tests/snapshots.rs +++ b/tests/snapshots.rs @@ -571,6 +571,7 @@ fn convert_wgsl() { ("sprite", Targets::SPIRV), ("force_point_size_vertex_shader_webgl", Targets::GLSL), ("invariant", Targets::GLSL), + ("ray-query", Targets::SPIRV), ]; for &(name, targets) in inputs.iter() {