diff --git a/naga/src/common/wgsl/types.rs b/naga/src/common/wgsl/types.rs index 4b2ee5ed0d..e47ec2bab6 100644 --- a/naga/src/common/wgsl/types.rs +++ b/naga/src/common/wgsl/types.rs @@ -2,7 +2,7 @@ use super::{address_space_str, ToWgsl, TryToWgsl}; use crate::common; -use crate::{Handle, TypeInner}; +use crate::{Handle, Scalar, TypeInner}; use core::fmt::Write; @@ -61,6 +61,14 @@ pub trait TypeContext { fn write_type_inner(&self, inner: &TypeInner, out: &mut W) -> core::fmt::Result { try_write_type_inner(self, inner, out) } + + /// Write the [`Scalar`] `scalar` as a WGSL type. + fn write_scalar(&self, scalar: Scalar, out: &mut W) -> core::fmt::Result { + match scalar.try_to_wgsl() { + Some(string) => out.write_str(string), + None => unreachable!("validation should have forbidden Scalar: {scalar:?}"), + } + } } fn try_write_type_inner(ctx: &C, inner: &TypeInner, out: &mut W) -> core::fmt::Result @@ -68,22 +76,12 @@ where C: TypeContext + ?Sized, W: Write, { - fn unwrap_to_wgsl(value: T) -> &'static str { - value.try_to_wgsl().unwrap_or_else(|| { - unreachable!( - "validation should have forbidden {}: {value:?}", - T::DESCRIPTION - ); - }) - } - match *inner { - TypeInner::Vector { size, scalar } => write!( - out, - "vec{}<{}>", - common::vector_size_str(size), - unwrap_to_wgsl(scalar), - )?, + TypeInner::Vector { size, scalar } => { + write!(out, "vec{}<", common::vector_size_str(size))?; + ctx.write_scalar(scalar, out)?; + out.write_str(">")?; + } TypeInner::Sampler { comparison: false } => { write!(out, "sampler")?; } @@ -103,11 +101,9 @@ where match class { Ic::Sampled { kind, multi } => { let multisampled_str = if multi { "multisampled_" } else { "" }; - let type_str = unwrap_to_wgsl(crate::Scalar { kind, width: 4 }); - write!( - out, - "texture_{multisampled_str}{dim_str}{arrayed_str}<{type_str}>" - )?; + write!(out, "texture_{multisampled_str}{dim_str}{arrayed_str}<")?; + ctx.write_scalar(Scalar { kind, width: 4 }, out)?; + out.write_str(">")?; } Ic::Depth { multi } => { let multisampled_str = if multi { "multisampled_" } else { "" }; @@ -137,10 +133,12 @@ where } } TypeInner::Scalar(scalar) => { - write!(out, "{}", unwrap_to_wgsl(scalar))?; + ctx.write_scalar(scalar, out)?; } TypeInner::Atomic(scalar) => { - write!(out, "atomic<{}>", unwrap_to_wgsl(scalar))?; + out.write_str("atomic<")?; + ctx.write_scalar(scalar, out)?; + out.write_str(">")?; } TypeInner::Array { base, @@ -189,11 +187,12 @@ where } => { write!( out, - "mat{}x{}<{}>", + "mat{}x{}<", common::vector_size_str(columns), common::vector_size_str(rows), - unwrap_to_wgsl(scalar) )?; + ctx.write_scalar(scalar, out)?; + out.write_str(">")?; } TypeInner::Pointer { base, space } => { let (address, maybe_access) = address_space_str(space); @@ -218,7 +217,8 @@ where } => { let (address, maybe_access) = address_space_str(space); if let Some(space) = address { - write!(out, "ptr<{}, {}", space, unwrap_to_wgsl(scalar))?; + write!(out, "ptr<{}, ", space)?; + ctx.write_scalar(scalar, out)?; if let Some(access) = maybe_access { write!(out, ", {access}")?; } @@ -234,13 +234,9 @@ where } => { let (address, maybe_access) = address_space_str(space); if let Some(space) = address { - write!( - out, - "ptr<{}, vec{}<{}>", - space, - common::vector_size_str(size), - unwrap_to_wgsl(scalar) - )?; + write!(out, "ptr<{}, vec{}<", space, common::vector_size_str(size),)?; + ctx.write_scalar(scalar, out)?; + out.write_str(">")?; if let Some(access) = maybe_access { write!(out, ", {access}")?; }