[naga wgsl] Define ToWgsl and TryToWgsl traits.

Define new traits in `common::wgsl`, `ToWgsl` and `TryToWgsl`, for
getting the WGSL representation of some Naga IR types as `&'static
str` values:

- `MathFunction`
- `BuiltIn`
- `Interpolation`
- `Sampling`
- `StorageFormat`

Use these functions in the WGSL backend, taking advantage of
`TryToWgsl` to consolidate error reporting.
This commit is contained in:
Jim Blandy
2025-03-03 12:22:02 -08:00
parent 2a364d2bfa
commit 34ffbee1b7
3 changed files with 279 additions and 213 deletions

View File

@@ -14,6 +14,8 @@ use thiserror::Error;
pub use writer::{Writer, WriterFlags};
use crate::common::wgsl;
#[derive(Error, Debug)]
pub enum Error {
#[error(transparent)]
@@ -46,6 +48,20 @@ impl Error {
}
}
trait ToWgslIfImplemented {
fn to_wgsl_if_implemented(self) -> Result<&'static str, Error>;
}
impl<T> ToWgslIfImplemented for T
where
T: wgsl::TryToWgsl + core::fmt::Debug + Copy,
{
fn to_wgsl_if_implemented(self) -> Result<&'static str, Error> {
self.try_to_wgsl()
.ok_or_else(|| Error::unsupported(T::DESCRIPTION, self))
}
}
pub fn write_string(
module: &crate::Module,
info: &crate::valid::ModuleInfo,

View File

@@ -7,9 +7,11 @@ use alloc::{
use core::fmt::Write;
use super::Error;
use super::ToWgslIfImplemented as _;
use crate::back::wgsl::polyfill::InversePolyfill;
use crate::{
back::{self, Baked},
common::wgsl::{ToWgsl, TryToWgsl},
proc::{self, ExpressionKindTracker, NameKey},
valid, Handle, Module, ShaderStage, TypeInner,
};
@@ -316,7 +318,7 @@ impl<W: Write> Writer<W> {
Attribute::Location(id) => write!(self.out, "@location({id}) ")?,
Attribute::SecondBlendSource => write!(self.out, "@second_blend_source ")?,
Attribute::BuiltIn(builtin_attrib) => {
let builtin = builtin_str(builtin_attrib)?;
let builtin = builtin_attrib.to_wgsl_if_implemented()?;
write!(self.out, "@builtin({builtin}) ")?;
}
Attribute::Stage(shader_stage) => {
@@ -339,24 +341,18 @@ impl<W: Write> Writer<W> {
Attribute::Invariant => write!(self.out, "@invariant ")?,
Attribute::Interpolate(interpolation, sampling) => {
if sampling.is_some() && sampling != Some(crate::Sampling::Center) {
write!(
self.out,
"@interpolate({}, {}) ",
interpolation_str(
interpolation.unwrap_or(crate::Interpolation::Perspective)
),
sampling_str(sampling.unwrap_or(crate::Sampling::Center))
)?;
let interpolation = interpolation
.unwrap_or(crate::Interpolation::Perspective)
.to_wgsl();
let sampling = sampling.unwrap_or(crate::Sampling::Center).to_wgsl();
write!(self.out, "@interpolate({interpolation}, {sampling}) ")?;
} else if interpolation.is_some()
&& interpolation != Some(crate::Interpolation::Perspective)
{
write!(
self.out,
"@interpolate({}) ",
interpolation_str(
interpolation.unwrap_or(crate::Interpolation::Perspective)
)
)?;
let interpolation = interpolation
.unwrap_or(crate::Interpolation::Perspective)
.to_wgsl();
write!(self.out, "@interpolate({interpolation}) ")?;
}
}
};
@@ -455,7 +451,7 @@ impl<W: Write> Writer<W> {
Ic::Storage { format, access } => (
"storage_",
"",
storage_format_str(format),
format.to_wgsl(),
if access.contains(crate::StorageAccess::ATOMIC) {
",atomic"
} else if access
@@ -1679,98 +1675,19 @@ impl<W: Write> Writer<W> {
InversePolyfill(InversePolyfill),
}
let function = match fun {
Mf::Abs => Function::Regular("abs"),
Mf::Min => Function::Regular("min"),
Mf::Max => Function::Regular("max"),
Mf::Clamp => Function::Regular("clamp"),
Mf::Saturate => Function::Regular("saturate"),
// trigonometry
Mf::Cos => Function::Regular("cos"),
Mf::Cosh => Function::Regular("cosh"),
Mf::Sin => Function::Regular("sin"),
Mf::Sinh => Function::Regular("sinh"),
Mf::Tan => Function::Regular("tan"),
Mf::Tanh => Function::Regular("tanh"),
Mf::Acos => Function::Regular("acos"),
Mf::Asin => Function::Regular("asin"),
Mf::Atan => Function::Regular("atan"),
Mf::Atan2 => Function::Regular("atan2"),
Mf::Asinh => Function::Regular("asinh"),
Mf::Acosh => Function::Regular("acosh"),
Mf::Atanh => Function::Regular("atanh"),
Mf::Radians => Function::Regular("radians"),
Mf::Degrees => Function::Regular("degrees"),
// decomposition
Mf::Ceil => Function::Regular("ceil"),
Mf::Floor => Function::Regular("floor"),
Mf::Round => Function::Regular("round"),
Mf::Fract => Function::Regular("fract"),
Mf::Trunc => Function::Regular("trunc"),
Mf::Modf => Function::Regular("modf"),
Mf::Frexp => Function::Regular("frexp"),
Mf::Ldexp => Function::Regular("ldexp"),
// exponent
Mf::Exp => Function::Regular("exp"),
Mf::Exp2 => Function::Regular("exp2"),
Mf::Log => Function::Regular("log"),
Mf::Log2 => Function::Regular("log2"),
Mf::Pow => Function::Regular("pow"),
// geometry
Mf::Dot => Function::Regular("dot"),
Mf::Cross => Function::Regular("cross"),
Mf::Distance => Function::Regular("distance"),
Mf::Length => Function::Regular("length"),
Mf::Normalize => Function::Regular("normalize"),
Mf::FaceForward => Function::Regular("faceForward"),
Mf::Reflect => Function::Regular("reflect"),
Mf::Refract => Function::Regular("refract"),
// computational
Mf::Sign => Function::Regular("sign"),
Mf::Fma => Function::Regular("fma"),
Mf::Mix => Function::Regular("mix"),
Mf::Step => Function::Regular("step"),
Mf::SmoothStep => Function::Regular("smoothstep"),
Mf::Sqrt => Function::Regular("sqrt"),
Mf::InverseSqrt => Function::Regular("inverseSqrt"),
Mf::Transpose => Function::Regular("transpose"),
Mf::Determinant => Function::Regular("determinant"),
Mf::QuantizeToF16 => Function::Regular("quantizeToF16"),
// bits
Mf::CountTrailingZeros => Function::Regular("countTrailingZeros"),
Mf::CountLeadingZeros => Function::Regular("countLeadingZeros"),
Mf::CountOneBits => Function::Regular("countOneBits"),
Mf::ReverseBits => Function::Regular("reverseBits"),
Mf::ExtractBits => Function::Regular("extractBits"),
Mf::InsertBits => Function::Regular("insertBits"),
Mf::FirstTrailingBit => Function::Regular("firstTrailingBit"),
Mf::FirstLeadingBit => Function::Regular("firstLeadingBit"),
// data packing
Mf::Pack4x8snorm => Function::Regular("pack4x8snorm"),
Mf::Pack4x8unorm => Function::Regular("pack4x8unorm"),
Mf::Pack2x16snorm => Function::Regular("pack2x16snorm"),
Mf::Pack2x16unorm => Function::Regular("pack2x16unorm"),
Mf::Pack2x16float => Function::Regular("pack2x16float"),
Mf::Pack4xI8 => Function::Regular("pack4xI8"),
Mf::Pack4xU8 => Function::Regular("pack4xU8"),
// data unpacking
Mf::Unpack4x8snorm => Function::Regular("unpack4x8snorm"),
Mf::Unpack4x8unorm => Function::Regular("unpack4x8unorm"),
Mf::Unpack2x16snorm => Function::Regular("unpack2x16snorm"),
Mf::Unpack2x16unorm => Function::Regular("unpack2x16unorm"),
Mf::Unpack2x16float => Function::Regular("unpack2x16float"),
Mf::Unpack4xI8 => Function::Regular("unpack4xI8"),
Mf::Unpack4xU8 => Function::Regular("unpack4xU8"),
Mf::Inverse => {
let ty = func_ctx.resolve_type(arg, &module.types);
let function = match fun.try_to_wgsl() {
Some(name) => Function::Regular(name),
None => match fun {
Mf::Inverse => {
let ty = func_ctx.resolve_type(arg, &module.types);
let Some(overload) = InversePolyfill::find_overload(ty) else {
return Err(Error::unsupported("math function", fun));
};
let Some(overload) = InversePolyfill::find_overload(ty) else {
return Err(Error::unsupported("math function", fun));
};
Function::InversePolyfill(overload)
}
Mf::Outer => return Err(Error::unsupported("math function", fun)),
Function::InversePolyfill(overload)
}
_ => return Err(Error::unsupported("math function", fun)),
},
};
match function {
@@ -1952,39 +1869,6 @@ impl<W: Write> Writer<W> {
}
}
fn builtin_str(built_in: crate::BuiltIn) -> Result<&'static str, Error> {
use crate::BuiltIn as Bi;
Ok(match built_in {
Bi::VertexIndex => "vertex_index",
Bi::InstanceIndex => "instance_index",
Bi::Position { .. } => "position",
Bi::FrontFacing => "front_facing",
Bi::FragDepth => "frag_depth",
Bi::LocalInvocationId => "local_invocation_id",
Bi::LocalInvocationIndex => "local_invocation_index",
Bi::GlobalInvocationId => "global_invocation_id",
Bi::WorkGroupId => "workgroup_id",
Bi::NumWorkGroups => "num_workgroups",
Bi::SampleIndex => "sample_index",
Bi::SampleMask => "sample_mask",
Bi::PrimitiveIndex => "primitive_index",
Bi::ViewIndex => "view_index",
Bi::NumSubgroups => "num_subgroups",
Bi::SubgroupId => "subgroup_id",
Bi::SubgroupSize => "subgroup_size",
Bi::SubgroupInvocationId => "subgroup_invocation_id",
Bi::BaseInstance
| Bi::BaseVertex
| Bi::ClipDistance
| Bi::CullDistance
| Bi::PointSize
| Bi::PointCoord
| Bi::WorkGroupSize
| Bi::DrawID => return Err(Error::unsupported("builtin", built_in)),
})
}
const fn image_dimension_str(dim: crate::ImageDimension) -> &'static str {
use crate::ImageDimension as IDim;
@@ -2033,78 +1917,6 @@ const fn scalar_kind_str(scalar: crate::Scalar) -> &'static str {
}
}
const fn storage_format_str(format: crate::StorageFormat) -> &'static str {
use crate::StorageFormat as Sf;
match format {
Sf::R8Unorm => "r8unorm",
Sf::R8Snorm => "r8snorm",
Sf::R8Uint => "r8uint",
Sf::R8Sint => "r8sint",
Sf::R16Uint => "r16uint",
Sf::R16Sint => "r16sint",
Sf::R16Float => "r16float",
Sf::Rg8Unorm => "rg8unorm",
Sf::Rg8Snorm => "rg8snorm",
Sf::Rg8Uint => "rg8uint",
Sf::Rg8Sint => "rg8sint",
Sf::R32Uint => "r32uint",
Sf::R32Sint => "r32sint",
Sf::R32Float => "r32float",
Sf::Rg16Uint => "rg16uint",
Sf::Rg16Sint => "rg16sint",
Sf::Rg16Float => "rg16float",
Sf::Rgba8Unorm => "rgba8unorm",
Sf::Rgba8Snorm => "rgba8snorm",
Sf::Rgba8Uint => "rgba8uint",
Sf::Rgba8Sint => "rgba8sint",
Sf::Bgra8Unorm => "bgra8unorm",
Sf::Rgb10a2Uint => "rgb10a2uint",
Sf::Rgb10a2Unorm => "rgb10a2unorm",
Sf::Rg11b10Ufloat => "rg11b10float",
Sf::R64Uint => "r64uint",
Sf::Rg32Uint => "rg32uint",
Sf::Rg32Sint => "rg32sint",
Sf::Rg32Float => "rg32float",
Sf::Rgba16Uint => "rgba16uint",
Sf::Rgba16Sint => "rgba16sint",
Sf::Rgba16Float => "rgba16float",
Sf::Rgba32Uint => "rgba32uint",
Sf::Rgba32Sint => "rgba32sint",
Sf::Rgba32Float => "rgba32float",
Sf::R16Unorm => "r16unorm",
Sf::R16Snorm => "r16snorm",
Sf::Rg16Unorm => "rg16unorm",
Sf::Rg16Snorm => "rg16snorm",
Sf::Rgba16Unorm => "rgba16unorm",
Sf::Rgba16Snorm => "rgba16snorm",
}
}
/// Helper function that returns the string corresponding to the WGSL interpolation qualifier
const fn interpolation_str(interpolation: crate::Interpolation) -> &'static str {
use crate::Interpolation as I;
match interpolation {
I::Perspective => "perspective",
I::Linear => "linear",
I::Flat => "flat",
}
}
/// Return the WGSL auxiliary qualifier for the given sampling value.
const fn sampling_str(sampling: crate::Sampling) -> &'static str {
use crate::Sampling as S;
match sampling {
S::Center => "",
S::Centroid => "centroid",
S::Sample => "sample",
S::First => "first",
S::Either => "either",
}
}
const fn address_space_str(
space: crate::AddressSpace,
) -> (Option<&'static str>, Option<&'static str>) {

View File

@@ -67,3 +67,241 @@ impl StandardFilterableTriggeringRule {
}
}
}
/// Types that can return the WGSL source representation of their
/// values as a `'static` string.
///
/// This trait is specifically for types whose WGSL forms are simple
/// enough that they can always be returned as a static string.
///
/// - If only some values have a WGSL representation, consider
/// implementing [`TryToWgsl`] instead.
///
/// - If a type's WGSL form requires dynamic formatting, so that
/// returning a `&'static str` isn't feasible, consider implementing
/// [`std::fmt::Display`] on some wrapper type instead.
pub trait ToWgsl: Sized {
/// Return WGSL source code representation of `self`.
fn to_wgsl(self) -> &'static str;
}
/// Types that may be able to return the WGSL source representation
/// for their values as a `'static' string.
///
/// This trait is specifically for types whose values are either
/// simple enough that their WGSL form can be represented a static
/// string, or aren't representable in WGSL at all.
///
/// - If all values in the type have `&'static str` representations in
/// WGSL, consider implementing [`ToWgsl`] instead.
///
/// - If a type's WGSL form requires dynamic formatting, so that
/// returning a `&'static str` isn't feasible, consider implementing
/// [`std::fmt::Display`] on some wrapper type instead.
pub trait TryToWgsl: Sized {
/// Return the WGSL form of `self` as a `'static` string.
///
/// If `self` doesn't have a representation in WGSL (standard or
/// as extended by Naga), then return `None`.
fn try_to_wgsl(self) -> Option<&'static str>;
/// What kind of WGSL thing `Self` represents.
const DESCRIPTION: &'static str;
}
impl TryToWgsl for crate::MathFunction {
const DESCRIPTION: &'static str = "math function";
fn try_to_wgsl(self) -> Option<&'static str> {
use crate::MathFunction as Mf;
Some(match self {
Mf::Abs => "abs",
Mf::Min => "min",
Mf::Max => "max",
Mf::Clamp => "clamp",
Mf::Saturate => "saturate",
Mf::Cos => "cos",
Mf::Cosh => "cosh",
Mf::Sin => "sin",
Mf::Sinh => "sinh",
Mf::Tan => "tan",
Mf::Tanh => "tanh",
Mf::Acos => "acos",
Mf::Asin => "asin",
Mf::Atan => "atan",
Mf::Atan2 => "atan2",
Mf::Asinh => "asinh",
Mf::Acosh => "acosh",
Mf::Atanh => "atanh",
Mf::Radians => "radians",
Mf::Degrees => "degrees",
Mf::Ceil => "ceil",
Mf::Floor => "floor",
Mf::Round => "round",
Mf::Fract => "fract",
Mf::Trunc => "trunc",
Mf::Modf => "modf",
Mf::Frexp => "frexp",
Mf::Ldexp => "ldexp",
Mf::Exp => "exp",
Mf::Exp2 => "exp2",
Mf::Log => "log",
Mf::Log2 => "log2",
Mf::Pow => "pow",
Mf::Dot => "dot",
Mf::Cross => "cross",
Mf::Distance => "distance",
Mf::Length => "length",
Mf::Normalize => "normalize",
Mf::FaceForward => "faceForward",
Mf::Reflect => "reflect",
Mf::Refract => "refract",
Mf::Sign => "sign",
Mf::Fma => "fma",
Mf::Mix => "mix",
Mf::Step => "step",
Mf::SmoothStep => "smoothstep",
Mf::Sqrt => "sqrt",
Mf::InverseSqrt => "inverseSqrt",
Mf::Transpose => "transpose",
Mf::Determinant => "determinant",
Mf::QuantizeToF16 => "quantizeToF16",
Mf::CountTrailingZeros => "countTrailingZeros",
Mf::CountLeadingZeros => "countLeadingZeros",
Mf::CountOneBits => "countOneBits",
Mf::ReverseBits => "reverseBits",
Mf::ExtractBits => "extractBits",
Mf::InsertBits => "insertBits",
Mf::FirstTrailingBit => "firstTrailingBit",
Mf::FirstLeadingBit => "firstLeadingBit",
Mf::Pack4x8snorm => "pack4x8snorm",
Mf::Pack4x8unorm => "pack4x8unorm",
Mf::Pack2x16snorm => "pack2x16snorm",
Mf::Pack2x16unorm => "pack2x16unorm",
Mf::Pack2x16float => "pack2x16float",
Mf::Pack4xI8 => "pack4xI8",
Mf::Pack4xU8 => "pack4xU8",
Mf::Unpack4x8snorm => "unpack4x8snorm",
Mf::Unpack4x8unorm => "unpack4x8unorm",
Mf::Unpack2x16snorm => "unpack2x16snorm",
Mf::Unpack2x16unorm => "unpack2x16unorm",
Mf::Unpack2x16float => "unpack2x16float",
Mf::Unpack4xI8 => "unpack4xI8",
Mf::Unpack4xU8 => "unpack4xU8",
// Non-standard math functions.
Mf::Inverse | Mf::Outer => return None,
})
}
}
impl TryToWgsl for crate::BuiltIn {
const DESCRIPTION: &'static str = "builtin value";
fn try_to_wgsl(self) -> Option<&'static str> {
use crate::BuiltIn as Bi;
Some(match self {
Bi::Position { .. } => "position",
Bi::ViewIndex => "view_index",
Bi::InstanceIndex => "instance_index",
Bi::VertexIndex => "vertex_index",
Bi::FragDepth => "frag_depth",
Bi::FrontFacing => "front_facing",
Bi::PrimitiveIndex => "primitive_index",
Bi::SampleIndex => "sample_index",
Bi::SampleMask => "sample_mask",
Bi::GlobalInvocationId => "global_invocation_id",
Bi::LocalInvocationId => "local_invocation_id",
Bi::LocalInvocationIndex => "local_invocation_index",
Bi::WorkGroupId => "workgroup_id",
Bi::NumWorkGroups => "num_workgroups",
Bi::NumSubgroups => "num_subgroups",
Bi::SubgroupId => "subgroup_id",
Bi::SubgroupSize => "subgroup_size",
Bi::SubgroupInvocationId => "subgroup_invocation_id",
// Non-standard built-ins.
Bi::BaseInstance
| Bi::BaseVertex
| Bi::ClipDistance
| Bi::CullDistance
| Bi::PointSize
| Bi::DrawID
| Bi::PointCoord
| Bi::WorkGroupSize => return None,
})
}
}
impl ToWgsl for crate::Interpolation {
fn to_wgsl(self) -> &'static str {
match self {
crate::Interpolation::Perspective => "perspective",
crate::Interpolation::Linear => "linear",
crate::Interpolation::Flat => "flat",
}
}
}
impl ToWgsl for crate::Sampling {
fn to_wgsl(self) -> &'static str {
match self {
crate::Sampling::Center => "center",
crate::Sampling::Centroid => "centroid",
crate::Sampling::Sample => "sample",
crate::Sampling::First => "first",
crate::Sampling::Either => "either",
}
}
}
impl ToWgsl for crate::StorageFormat {
fn to_wgsl(self) -> &'static str {
use crate::StorageFormat as Sf;
match self {
Sf::R8Unorm => "r8unorm",
Sf::R8Snorm => "r8snorm",
Sf::R8Uint => "r8uint",
Sf::R8Sint => "r8sint",
Sf::R16Uint => "r16uint",
Sf::R16Sint => "r16sint",
Sf::R16Float => "r16float",
Sf::Rg8Unorm => "rg8unorm",
Sf::Rg8Snorm => "rg8snorm",
Sf::Rg8Uint => "rg8uint",
Sf::Rg8Sint => "rg8sint",
Sf::R32Uint => "r32uint",
Sf::R32Sint => "r32sint",
Sf::R32Float => "r32float",
Sf::Rg16Uint => "rg16uint",
Sf::Rg16Sint => "rg16sint",
Sf::Rg16Float => "rg16float",
Sf::Rgba8Unorm => "rgba8unorm",
Sf::Rgba8Snorm => "rgba8snorm",
Sf::Rgba8Uint => "rgba8uint",
Sf::Rgba8Sint => "rgba8sint",
Sf::Bgra8Unorm => "bgra8unorm",
Sf::Rgb10a2Uint => "rgb10a2uint",
Sf::Rgb10a2Unorm => "rgb10a2unorm",
Sf::Rg11b10Ufloat => "rg11b10float",
Sf::R64Uint => "r64uint",
Sf::Rg32Uint => "rg32uint",
Sf::Rg32Sint => "rg32sint",
Sf::Rg32Float => "rg32float",
Sf::Rgba16Uint => "rgba16uint",
Sf::Rgba16Sint => "rgba16sint",
Sf::Rgba16Float => "rgba16float",
Sf::Rgba32Uint => "rgba32uint",
Sf::Rgba32Sint => "rgba32sint",
Sf::Rgba32Float => "rgba32float",
Sf::R16Unorm => "r16unorm",
Sf::R16Snorm => "r16snorm",
Sf::Rg16Unorm => "rg16unorm",
Sf::Rg16Snorm => "rg16snorm",
Sf::Rgba16Unorm => "rgba16unorm",
Sf::Rgba16Snorm => "rgba16snorm",
}
}
}