diff --git a/naga/src/back/wgsl/mod.rs b/naga/src/back/wgsl/mod.rs index 029e1fff0e..9430460dad 100644 --- a/naga/src/back/wgsl/mod.rs +++ b/naga/src/back/wgsl/mod.rs @@ -14,6 +14,8 @@ use thiserror::Error; pub use writer::{Writer, WriterFlags}; +use crate::common::wgsl; + #[derive(Error, Debug)] pub enum Error { #[error(transparent)] @@ -46,6 +48,20 @@ impl Error { } } +trait ToWgslIfImplemented { + fn to_wgsl_if_implemented(self) -> Result<&'static str, Error>; +} + +impl ToWgslIfImplemented for T +where + T: wgsl::TryToWgsl + core::fmt::Debug + Copy, +{ + fn to_wgsl_if_implemented(self) -> Result<&'static str, Error> { + self.try_to_wgsl() + .ok_or_else(|| Error::unsupported(T::DESCRIPTION, self)) + } +} + pub fn write_string( module: &crate::Module, info: &crate::valid::ModuleInfo, diff --git a/naga/src/back/wgsl/writer.rs b/naga/src/back/wgsl/writer.rs index 5dfda9fdfb..7c6b3bf242 100644 --- a/naga/src/back/wgsl/writer.rs +++ b/naga/src/back/wgsl/writer.rs @@ -7,9 +7,11 @@ use alloc::{ use core::fmt::Write; use super::Error; +use super::ToWgslIfImplemented as _; use crate::back::wgsl::polyfill::InversePolyfill; use crate::{ back::{self, Baked}, + common::wgsl::{ToWgsl, TryToWgsl}, proc::{self, ExpressionKindTracker, NameKey}, valid, Handle, Module, ShaderStage, TypeInner, }; @@ -316,7 +318,7 @@ impl Writer { Attribute::Location(id) => write!(self.out, "@location({id}) ")?, Attribute::SecondBlendSource => write!(self.out, "@second_blend_source ")?, Attribute::BuiltIn(builtin_attrib) => { - let builtin = builtin_str(builtin_attrib)?; + let builtin = builtin_attrib.to_wgsl_if_implemented()?; write!(self.out, "@builtin({builtin}) ")?; } Attribute::Stage(shader_stage) => { @@ -339,24 +341,18 @@ impl Writer { Attribute::Invariant => write!(self.out, "@invariant ")?, Attribute::Interpolate(interpolation, sampling) => { if sampling.is_some() && sampling != Some(crate::Sampling::Center) { - write!( - self.out, - "@interpolate({}, {}) ", - interpolation_str( - interpolation.unwrap_or(crate::Interpolation::Perspective) - ), - sampling_str(sampling.unwrap_or(crate::Sampling::Center)) - )?; + let interpolation = interpolation + .unwrap_or(crate::Interpolation::Perspective) + .to_wgsl(); + let sampling = sampling.unwrap_or(crate::Sampling::Center).to_wgsl(); + write!(self.out, "@interpolate({interpolation}, {sampling}) ")?; } else if interpolation.is_some() && interpolation != Some(crate::Interpolation::Perspective) { - write!( - self.out, - "@interpolate({}) ", - interpolation_str( - interpolation.unwrap_or(crate::Interpolation::Perspective) - ) - )?; + let interpolation = interpolation + .unwrap_or(crate::Interpolation::Perspective) + .to_wgsl(); + write!(self.out, "@interpolate({interpolation}) ")?; } } }; @@ -455,7 +451,7 @@ impl Writer { Ic::Storage { format, access } => ( "storage_", "", - storage_format_str(format), + format.to_wgsl(), if access.contains(crate::StorageAccess::ATOMIC) { ",atomic" } else if access @@ -1679,98 +1675,19 @@ impl Writer { InversePolyfill(InversePolyfill), } - let function = match fun { - Mf::Abs => Function::Regular("abs"), - Mf::Min => Function::Regular("min"), - Mf::Max => Function::Regular("max"), - Mf::Clamp => Function::Regular("clamp"), - Mf::Saturate => Function::Regular("saturate"), - // trigonometry - Mf::Cos => Function::Regular("cos"), - Mf::Cosh => Function::Regular("cosh"), - Mf::Sin => Function::Regular("sin"), - Mf::Sinh => Function::Regular("sinh"), - Mf::Tan => Function::Regular("tan"), - Mf::Tanh => Function::Regular("tanh"), - Mf::Acos => Function::Regular("acos"), - Mf::Asin => Function::Regular("asin"), - Mf::Atan => Function::Regular("atan"), - Mf::Atan2 => Function::Regular("atan2"), - Mf::Asinh => Function::Regular("asinh"), - Mf::Acosh => Function::Regular("acosh"), - Mf::Atanh => Function::Regular("atanh"), - Mf::Radians => Function::Regular("radians"), - Mf::Degrees => Function::Regular("degrees"), - // decomposition - Mf::Ceil => Function::Regular("ceil"), - Mf::Floor => Function::Regular("floor"), - Mf::Round => Function::Regular("round"), - Mf::Fract => Function::Regular("fract"), - Mf::Trunc => Function::Regular("trunc"), - Mf::Modf => Function::Regular("modf"), - Mf::Frexp => Function::Regular("frexp"), - Mf::Ldexp => Function::Regular("ldexp"), - // exponent - Mf::Exp => Function::Regular("exp"), - Mf::Exp2 => Function::Regular("exp2"), - Mf::Log => Function::Regular("log"), - Mf::Log2 => Function::Regular("log2"), - Mf::Pow => Function::Regular("pow"), - // geometry - Mf::Dot => Function::Regular("dot"), - Mf::Cross => Function::Regular("cross"), - Mf::Distance => Function::Regular("distance"), - Mf::Length => Function::Regular("length"), - Mf::Normalize => Function::Regular("normalize"), - Mf::FaceForward => Function::Regular("faceForward"), - Mf::Reflect => Function::Regular("reflect"), - Mf::Refract => Function::Regular("refract"), - // computational - Mf::Sign => Function::Regular("sign"), - Mf::Fma => Function::Regular("fma"), - Mf::Mix => Function::Regular("mix"), - Mf::Step => Function::Regular("step"), - Mf::SmoothStep => Function::Regular("smoothstep"), - Mf::Sqrt => Function::Regular("sqrt"), - Mf::InverseSqrt => Function::Regular("inverseSqrt"), - Mf::Transpose => Function::Regular("transpose"), - Mf::Determinant => Function::Regular("determinant"), - Mf::QuantizeToF16 => Function::Regular("quantizeToF16"), - // bits - Mf::CountTrailingZeros => Function::Regular("countTrailingZeros"), - Mf::CountLeadingZeros => Function::Regular("countLeadingZeros"), - Mf::CountOneBits => Function::Regular("countOneBits"), - Mf::ReverseBits => Function::Regular("reverseBits"), - Mf::ExtractBits => Function::Regular("extractBits"), - Mf::InsertBits => Function::Regular("insertBits"), - Mf::FirstTrailingBit => Function::Regular("firstTrailingBit"), - Mf::FirstLeadingBit => Function::Regular("firstLeadingBit"), - // data packing - Mf::Pack4x8snorm => Function::Regular("pack4x8snorm"), - Mf::Pack4x8unorm => Function::Regular("pack4x8unorm"), - Mf::Pack2x16snorm => Function::Regular("pack2x16snorm"), - Mf::Pack2x16unorm => Function::Regular("pack2x16unorm"), - Mf::Pack2x16float => Function::Regular("pack2x16float"), - Mf::Pack4xI8 => Function::Regular("pack4xI8"), - Mf::Pack4xU8 => Function::Regular("pack4xU8"), - // data unpacking - Mf::Unpack4x8snorm => Function::Regular("unpack4x8snorm"), - Mf::Unpack4x8unorm => Function::Regular("unpack4x8unorm"), - Mf::Unpack2x16snorm => Function::Regular("unpack2x16snorm"), - Mf::Unpack2x16unorm => Function::Regular("unpack2x16unorm"), - Mf::Unpack2x16float => Function::Regular("unpack2x16float"), - Mf::Unpack4xI8 => Function::Regular("unpack4xI8"), - Mf::Unpack4xU8 => Function::Regular("unpack4xU8"), - Mf::Inverse => { - let ty = func_ctx.resolve_type(arg, &module.types); + let function = match fun.try_to_wgsl() { + Some(name) => Function::Regular(name), + None => match fun { + Mf::Inverse => { + let ty = func_ctx.resolve_type(arg, &module.types); + let Some(overload) = InversePolyfill::find_overload(ty) else { + return Err(Error::unsupported("math function", fun)); + }; - let Some(overload) = InversePolyfill::find_overload(ty) else { - return Err(Error::unsupported("math function", fun)); - }; - - Function::InversePolyfill(overload) - } - Mf::Outer => return Err(Error::unsupported("math function", fun)), + Function::InversePolyfill(overload) + } + _ => return Err(Error::unsupported("math function", fun)), + }, }; match function { @@ -1952,39 +1869,6 @@ impl Writer { } } -fn builtin_str(built_in: crate::BuiltIn) -> Result<&'static str, Error> { - use crate::BuiltIn as Bi; - - Ok(match built_in { - Bi::VertexIndex => "vertex_index", - Bi::InstanceIndex => "instance_index", - Bi::Position { .. } => "position", - Bi::FrontFacing => "front_facing", - Bi::FragDepth => "frag_depth", - Bi::LocalInvocationId => "local_invocation_id", - Bi::LocalInvocationIndex => "local_invocation_index", - Bi::GlobalInvocationId => "global_invocation_id", - Bi::WorkGroupId => "workgroup_id", - Bi::NumWorkGroups => "num_workgroups", - Bi::SampleIndex => "sample_index", - Bi::SampleMask => "sample_mask", - Bi::PrimitiveIndex => "primitive_index", - Bi::ViewIndex => "view_index", - Bi::NumSubgroups => "num_subgroups", - Bi::SubgroupId => "subgroup_id", - Bi::SubgroupSize => "subgroup_size", - Bi::SubgroupInvocationId => "subgroup_invocation_id", - Bi::BaseInstance - | Bi::BaseVertex - | Bi::ClipDistance - | Bi::CullDistance - | Bi::PointSize - | Bi::PointCoord - | Bi::WorkGroupSize - | Bi::DrawID => return Err(Error::unsupported("builtin", built_in)), - }) -} - const fn image_dimension_str(dim: crate::ImageDimension) -> &'static str { use crate::ImageDimension as IDim; @@ -2033,78 +1917,6 @@ const fn scalar_kind_str(scalar: crate::Scalar) -> &'static str { } } -const fn storage_format_str(format: crate::StorageFormat) -> &'static str { - use crate::StorageFormat as Sf; - - match format { - Sf::R8Unorm => "r8unorm", - Sf::R8Snorm => "r8snorm", - Sf::R8Uint => "r8uint", - Sf::R8Sint => "r8sint", - Sf::R16Uint => "r16uint", - Sf::R16Sint => "r16sint", - Sf::R16Float => "r16float", - Sf::Rg8Unorm => "rg8unorm", - Sf::Rg8Snorm => "rg8snorm", - Sf::Rg8Uint => "rg8uint", - Sf::Rg8Sint => "rg8sint", - Sf::R32Uint => "r32uint", - Sf::R32Sint => "r32sint", - Sf::R32Float => "r32float", - Sf::Rg16Uint => "rg16uint", - Sf::Rg16Sint => "rg16sint", - Sf::Rg16Float => "rg16float", - Sf::Rgba8Unorm => "rgba8unorm", - Sf::Rgba8Snorm => "rgba8snorm", - Sf::Rgba8Uint => "rgba8uint", - Sf::Rgba8Sint => "rgba8sint", - Sf::Bgra8Unorm => "bgra8unorm", - Sf::Rgb10a2Uint => "rgb10a2uint", - Sf::Rgb10a2Unorm => "rgb10a2unorm", - Sf::Rg11b10Ufloat => "rg11b10float", - Sf::R64Uint => "r64uint", - Sf::Rg32Uint => "rg32uint", - Sf::Rg32Sint => "rg32sint", - Sf::Rg32Float => "rg32float", - Sf::Rgba16Uint => "rgba16uint", - Sf::Rgba16Sint => "rgba16sint", - Sf::Rgba16Float => "rgba16float", - Sf::Rgba32Uint => "rgba32uint", - Sf::Rgba32Sint => "rgba32sint", - Sf::Rgba32Float => "rgba32float", - Sf::R16Unorm => "r16unorm", - Sf::R16Snorm => "r16snorm", - Sf::Rg16Unorm => "rg16unorm", - Sf::Rg16Snorm => "rg16snorm", - Sf::Rgba16Unorm => "rgba16unorm", - Sf::Rgba16Snorm => "rgba16snorm", - } -} - -/// Helper function that returns the string corresponding to the WGSL interpolation qualifier -const fn interpolation_str(interpolation: crate::Interpolation) -> &'static str { - use crate::Interpolation as I; - - match interpolation { - I::Perspective => "perspective", - I::Linear => "linear", - I::Flat => "flat", - } -} - -/// Return the WGSL auxiliary qualifier for the given sampling value. -const fn sampling_str(sampling: crate::Sampling) -> &'static str { - use crate::Sampling as S; - - match sampling { - S::Center => "", - S::Centroid => "centroid", - S::Sample => "sample", - S::First => "first", - S::Either => "either", - } -} - const fn address_space_str( space: crate::AddressSpace, ) -> (Option<&'static str>, Option<&'static str>) { diff --git a/naga/src/common/wgsl.rs b/naga/src/common/wgsl.rs index 449ec918bb..a0de0465d3 100644 --- a/naga/src/common/wgsl.rs +++ b/naga/src/common/wgsl.rs @@ -67,3 +67,241 @@ impl StandardFilterableTriggeringRule { } } } + +/// Types that can return the WGSL source representation of their +/// values as a `'static` string. +/// +/// This trait is specifically for types whose WGSL forms are simple +/// enough that they can always be returned as a static string. +/// +/// - If only some values have a WGSL representation, consider +/// implementing [`TryToWgsl`] instead. +/// +/// - If a type's WGSL form requires dynamic formatting, so that +/// returning a `&'static str` isn't feasible, consider implementing +/// [`std::fmt::Display`] on some wrapper type instead. +pub trait ToWgsl: Sized { + /// Return WGSL source code representation of `self`. + fn to_wgsl(self) -> &'static str; +} + +/// Types that may be able to return the WGSL source representation +/// for their values as a `'static' string. +/// +/// This trait is specifically for types whose values are either +/// simple enough that their WGSL form can be represented a static +/// string, or aren't representable in WGSL at all. +/// +/// - If all values in the type have `&'static str` representations in +/// WGSL, consider implementing [`ToWgsl`] instead. +/// +/// - If a type's WGSL form requires dynamic formatting, so that +/// returning a `&'static str` isn't feasible, consider implementing +/// [`std::fmt::Display`] on some wrapper type instead. +pub trait TryToWgsl: Sized { + /// Return the WGSL form of `self` as a `'static` string. + /// + /// If `self` doesn't have a representation in WGSL (standard or + /// as extended by Naga), then return `None`. + fn try_to_wgsl(self) -> Option<&'static str>; + + /// What kind of WGSL thing `Self` represents. + const DESCRIPTION: &'static str; +} + +impl TryToWgsl for crate::MathFunction { + const DESCRIPTION: &'static str = "math function"; + + fn try_to_wgsl(self) -> Option<&'static str> { + use crate::MathFunction as Mf; + + Some(match self { + Mf::Abs => "abs", + Mf::Min => "min", + Mf::Max => "max", + Mf::Clamp => "clamp", + Mf::Saturate => "saturate", + Mf::Cos => "cos", + Mf::Cosh => "cosh", + Mf::Sin => "sin", + Mf::Sinh => "sinh", + Mf::Tan => "tan", + Mf::Tanh => "tanh", + Mf::Acos => "acos", + Mf::Asin => "asin", + Mf::Atan => "atan", + Mf::Atan2 => "atan2", + Mf::Asinh => "asinh", + Mf::Acosh => "acosh", + Mf::Atanh => "atanh", + Mf::Radians => "radians", + Mf::Degrees => "degrees", + Mf::Ceil => "ceil", + Mf::Floor => "floor", + Mf::Round => "round", + Mf::Fract => "fract", + Mf::Trunc => "trunc", + Mf::Modf => "modf", + Mf::Frexp => "frexp", + Mf::Ldexp => "ldexp", + Mf::Exp => "exp", + Mf::Exp2 => "exp2", + Mf::Log => "log", + Mf::Log2 => "log2", + Mf::Pow => "pow", + Mf::Dot => "dot", + Mf::Cross => "cross", + Mf::Distance => "distance", + Mf::Length => "length", + Mf::Normalize => "normalize", + Mf::FaceForward => "faceForward", + Mf::Reflect => "reflect", + Mf::Refract => "refract", + Mf::Sign => "sign", + Mf::Fma => "fma", + Mf::Mix => "mix", + Mf::Step => "step", + Mf::SmoothStep => "smoothstep", + Mf::Sqrt => "sqrt", + Mf::InverseSqrt => "inverseSqrt", + Mf::Transpose => "transpose", + Mf::Determinant => "determinant", + Mf::QuantizeToF16 => "quantizeToF16", + Mf::CountTrailingZeros => "countTrailingZeros", + Mf::CountLeadingZeros => "countLeadingZeros", + Mf::CountOneBits => "countOneBits", + Mf::ReverseBits => "reverseBits", + Mf::ExtractBits => "extractBits", + Mf::InsertBits => "insertBits", + Mf::FirstTrailingBit => "firstTrailingBit", + Mf::FirstLeadingBit => "firstLeadingBit", + Mf::Pack4x8snorm => "pack4x8snorm", + Mf::Pack4x8unorm => "pack4x8unorm", + Mf::Pack2x16snorm => "pack2x16snorm", + Mf::Pack2x16unorm => "pack2x16unorm", + Mf::Pack2x16float => "pack2x16float", + Mf::Pack4xI8 => "pack4xI8", + Mf::Pack4xU8 => "pack4xU8", + Mf::Unpack4x8snorm => "unpack4x8snorm", + Mf::Unpack4x8unorm => "unpack4x8unorm", + Mf::Unpack2x16snorm => "unpack2x16snorm", + Mf::Unpack2x16unorm => "unpack2x16unorm", + Mf::Unpack2x16float => "unpack2x16float", + Mf::Unpack4xI8 => "unpack4xI8", + Mf::Unpack4xU8 => "unpack4xU8", + + // Non-standard math functions. + Mf::Inverse | Mf::Outer => return None, + }) + } +} + +impl TryToWgsl for crate::BuiltIn { + const DESCRIPTION: &'static str = "builtin value"; + + fn try_to_wgsl(self) -> Option<&'static str> { + use crate::BuiltIn as Bi; + Some(match self { + Bi::Position { .. } => "position", + Bi::ViewIndex => "view_index", + Bi::InstanceIndex => "instance_index", + Bi::VertexIndex => "vertex_index", + Bi::FragDepth => "frag_depth", + Bi::FrontFacing => "front_facing", + Bi::PrimitiveIndex => "primitive_index", + Bi::SampleIndex => "sample_index", + Bi::SampleMask => "sample_mask", + Bi::GlobalInvocationId => "global_invocation_id", + Bi::LocalInvocationId => "local_invocation_id", + Bi::LocalInvocationIndex => "local_invocation_index", + Bi::WorkGroupId => "workgroup_id", + Bi::NumWorkGroups => "num_workgroups", + Bi::NumSubgroups => "num_subgroups", + Bi::SubgroupId => "subgroup_id", + Bi::SubgroupSize => "subgroup_size", + Bi::SubgroupInvocationId => "subgroup_invocation_id", + + // Non-standard built-ins. + Bi::BaseInstance + | Bi::BaseVertex + | Bi::ClipDistance + | Bi::CullDistance + | Bi::PointSize + | Bi::DrawID + | Bi::PointCoord + | Bi::WorkGroupSize => return None, + }) + } +} + +impl ToWgsl for crate::Interpolation { + fn to_wgsl(self) -> &'static str { + match self { + crate::Interpolation::Perspective => "perspective", + crate::Interpolation::Linear => "linear", + crate::Interpolation::Flat => "flat", + } + } +} + +impl ToWgsl for crate::Sampling { + fn to_wgsl(self) -> &'static str { + match self { + crate::Sampling::Center => "center", + crate::Sampling::Centroid => "centroid", + crate::Sampling::Sample => "sample", + crate::Sampling::First => "first", + crate::Sampling::Either => "either", + } + } +} + +impl ToWgsl for crate::StorageFormat { + fn to_wgsl(self) -> &'static str { + use crate::StorageFormat as Sf; + + match self { + Sf::R8Unorm => "r8unorm", + Sf::R8Snorm => "r8snorm", + Sf::R8Uint => "r8uint", + Sf::R8Sint => "r8sint", + Sf::R16Uint => "r16uint", + Sf::R16Sint => "r16sint", + Sf::R16Float => "r16float", + Sf::Rg8Unorm => "rg8unorm", + Sf::Rg8Snorm => "rg8snorm", + Sf::Rg8Uint => "rg8uint", + Sf::Rg8Sint => "rg8sint", + Sf::R32Uint => "r32uint", + Sf::R32Sint => "r32sint", + Sf::R32Float => "r32float", + Sf::Rg16Uint => "rg16uint", + Sf::Rg16Sint => "rg16sint", + Sf::Rg16Float => "rg16float", + Sf::Rgba8Unorm => "rgba8unorm", + Sf::Rgba8Snorm => "rgba8snorm", + Sf::Rgba8Uint => "rgba8uint", + Sf::Rgba8Sint => "rgba8sint", + Sf::Bgra8Unorm => "bgra8unorm", + Sf::Rgb10a2Uint => "rgb10a2uint", + Sf::Rgb10a2Unorm => "rgb10a2unorm", + Sf::Rg11b10Ufloat => "rg11b10float", + Sf::R64Uint => "r64uint", + Sf::Rg32Uint => "rg32uint", + Sf::Rg32Sint => "rg32sint", + Sf::Rg32Float => "rg32float", + Sf::Rgba16Uint => "rgba16uint", + Sf::Rgba16Sint => "rgba16sint", + Sf::Rgba16Float => "rgba16float", + Sf::Rgba32Uint => "rgba32uint", + Sf::Rgba32Sint => "rgba32sint", + Sf::Rgba32Float => "rgba32float", + Sf::R16Unorm => "r16unorm", + Sf::R16Snorm => "r16snorm", + Sf::Rg16Unorm => "rg16unorm", + Sf::Rg16Snorm => "rg16snorm", + Sf::Rgba16Unorm => "rgba16unorm", + Sf::Rgba16Snorm => "rgba16snorm", + } + } +}