From 276c978b70cdb8a7a6ab2520651b8acd04dc36bb Mon Sep 17 00:00:00 2001 From: Jim Blandy Date: Thu, 9 Nov 2023 09:23:15 -0800 Subject: [PATCH] [naga] Introduce `ScalarKind::AbstractInt` and `AbstractFloat`. Introduce new variants of `naga::ScalarKind`, `AbstractInt` and `AbstractFloat`, for representing WGSL abstract types. --- naga/src/back/glsl/mod.rs | 11 +++++++++ naga/src/back/hlsl/conv.rs | 5 +++- naga/src/back/msl/writer.rs | 4 ++++ naga/src/back/spv/block.rs | 2 +- naga/src/back/spv/image.rs | 4 ++++ naga/src/back/spv/writer.rs | 8 +++++++ naga/src/front/glsl/types.rs | 2 +- naga/src/front/wgsl/to_wgsl.rs | 2 ++ naga/src/lib.rs | 10 ++++++++ naga/src/proc/mod.rs | 6 ++++- naga/src/valid/expression.rs | 42 ++++++++++++++++++++++++++-------- naga/src/valid/type.rs | 11 ++++++++- wgpu-core/src/validation.rs | 4 +++- 13 files changed, 95 insertions(+), 16 deletions(-) diff --git a/naga/src/back/glsl/mod.rs b/naga/src/back/glsl/mod.rs index cd3075f70a..d08a0c02c2 100644 --- a/naga/src/back/glsl/mod.rs +++ b/naga/src/back/glsl/mod.rs @@ -3555,6 +3555,9 @@ impl<'a, W: Write> Writer<'a, W> { (Sk::Sint | Sk::Uint | Sk::Float, Sk::Bool, None) => { write!(self.out, "bool")? } + + (Sk::AbstractInt | Sk::AbstractFloat, _, _) + | (_, Sk::AbstractInt | Sk::AbstractFloat, _) => unreachable!(), }; write!(self.out, "(")?; @@ -4117,6 +4120,11 @@ impl<'a, W: Write> Writer<'a, W> { crate::ScalarKind::Uint => write!(self.out, "0u")?, crate::ScalarKind::Float => write!(self.out, "0.0")?, crate::ScalarKind::Sint => write!(self.out, "0")?, + crate::ScalarKind::AbstractInt | crate::ScalarKind::AbstractFloat => { + return Err(Error::Custom( + "Abstract types should not appear in IR presented to backends".to_string(), + )) + } } Ok(()) @@ -4345,6 +4353,9 @@ const fn glsl_scalar(scalar: crate::Scalar) -> Result, Err prefix: "b", full: "bool", }, + Sk::AbstractInt | Sk::AbstractFloat => { + return Err(Error::UnsupportedScalar(scalar)); + } }) } diff --git a/naga/src/back/hlsl/conv.rs b/naga/src/back/hlsl/conv.rs index da17c35704..b6918ddc42 100644 --- a/naga/src/back/hlsl/conv.rs +++ b/naga/src/back/hlsl/conv.rs @@ -10,7 +10,7 @@ impl crate::ScalarKind { Self::Float => "asfloat", Self::Sint => "asint", Self::Uint => "asuint", - Self::Bool => unreachable!(), + Self::Bool | Self::AbstractInt | Self::AbstractFloat => unreachable!(), } } } @@ -30,6 +30,9 @@ impl crate::Scalar { _ => Err(Error::UnsupportedScalar(self)), }, crate::ScalarKind::Bool => Ok("bool"), + crate::ScalarKind::AbstractInt | crate::ScalarKind::AbstractFloat => { + Err(Error::UnsupportedScalar(self)) + } } } } diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs index 17154c3cd5..de226af87b 100644 --- a/naga/src/back/msl/writer.rs +++ b/naga/src/back/msl/writer.rs @@ -338,6 +338,10 @@ impl crate::Scalar { kind: Sk::Bool, width: _, } => "bool", + Self { + kind: Sk::AbstractInt | Sk::AbstractFloat, + width: _, + } => unreachable!(), } } } diff --git a/naga/src/back/spv/block.rs b/naga/src/back/spv/block.rs index df6ecd00ff..84f8581521 100644 --- a/naga/src/back/spv/block.rs +++ b/naga/src/back/spv/block.rs @@ -1175,7 +1175,7 @@ impl<'w> BlockContext<'w> { let op = match src_scalar.kind { Sk::Sint | Sk::Uint => spirv::Op::INotEqual, Sk::Float => spirv::Op::FUnordNotEqual, - Sk::Bool => unreachable!(), + Sk::Bool | Sk::AbstractInt | Sk::AbstractFloat => unreachable!(), }; let zero_scalar_id = self.writer.get_constant_scalar_with(0, src_scalar)?; diff --git a/naga/src/back/spv/image.rs b/naga/src/back/spv/image.rs index fb9d44e7f0..460c906d47 100644 --- a/naga/src/back/spv/image.rs +++ b/naga/src/back/spv/image.rs @@ -334,6 +334,10 @@ impl<'w> BlockContext<'w> { (_, crate::ScalarKind::Bool | crate::ScalarKind::Float) => { unreachable!("we don't allow bool or float for array index") } + (crate::ScalarKind::AbstractInt | crate::ScalarKind::AbstractFloat, _) + | (_, crate::ScalarKind::AbstractInt | crate::ScalarKind::AbstractFloat) => { + unreachable!("abstract types should never reach backends") + } }; let reconciled_array_index_id = if let Some(cast) = cast { let component_ty_id = self.get_type_id(LookupType::Local(LocalType::Value { diff --git a/naga/src/back/spv/writer.rs b/naga/src/back/spv/writer.rs index ef0532b2ea..48fc64bf24 100644 --- a/naga/src/back/spv/writer.rs +++ b/naga/src/back/spv/writer.rs @@ -824,6 +824,9 @@ impl Writer { Instruction::type_float(id, bits) } Sk::Bool => Instruction::type_bool(id), + Sk::AbstractInt | Sk::AbstractFloat => { + unreachable!("abstract types should never reach the backend"); + } } } @@ -1591,6 +1594,11 @@ impl Writer { | crate::TypeInner::Vector { scalar, .. } => match scalar.kind { Sk::Uint | Sk::Sint | Sk::Bool => true, Sk::Float => false, + Sk::AbstractInt | Sk::AbstractFloat => { + return Err(Error::Validation( + "Abstract types should not appear in IR presented to backends", + )) + } }, _ => false, }; diff --git a/naga/src/front/glsl/types.rs b/naga/src/front/glsl/types.rs index f0a2705ad2..e87d76fffc 100644 --- a/naga/src/front/glsl/types.rs +++ b/naga/src/front/glsl/types.rs @@ -205,7 +205,7 @@ pub const fn type_power(scalar: Scalar) -> Option { ScalarKind::Uint => 1, ScalarKind::Float if scalar.width == 4 => 2, ScalarKind::Float => 3, - ScalarKind::Bool => return None, + ScalarKind::Bool | ScalarKind::AbstractInt | ScalarKind::AbstractFloat => return None, }) } diff --git a/naga/src/front/wgsl/to_wgsl.rs b/naga/src/front/wgsl/to_wgsl.rs index cdfa1f0b1f..c8331ace09 100644 --- a/naga/src/front/wgsl/to_wgsl.rs +++ b/naga/src/front/wgsl/to_wgsl.rs @@ -140,6 +140,8 @@ impl crate::Scalar { crate::ScalarKind::Uint => "u", crate::ScalarKind::Float => "f", crate::ScalarKind::Bool => return "bool".to_string(), + crate::ScalarKind::AbstractInt => return "{AbstractInt}".to_string(), + crate::ScalarKind::AbstractFloat => return "{AbstractFloat}".to_string(), }; format!("{}{}", prefix, self.width * 8) } diff --git a/naga/src/lib.rs b/naga/src/lib.rs index e140ad6aef..e45f463bd9 100644 --- a/naga/src/lib.rs +++ b/naga/src/lib.rs @@ -470,6 +470,16 @@ pub enum ScalarKind { Float, /// Boolean type. Bool, + + /// WGSL abstract integer type. + /// + /// These are forbidden by validation, and should never reach backends. + AbstractInt, + + /// Abstract floating-point type. + /// + /// These are forbidden by validation, and should never reach backends. + AbstractFloat, } /// Characteristics of a scalar type. diff --git a/naga/src/proc/mod.rs b/naga/src/proc/mod.rs index 4f2f5c705d..687527049e 100644 --- a/naga/src/proc/mod.rs +++ b/naga/src/proc/mod.rs @@ -71,7 +71,11 @@ impl From for super::ScalarKind { impl super::ScalarKind { pub const fn is_numeric(self) -> bool { match self { - crate::ScalarKind::Sint | crate::ScalarKind::Uint | crate::ScalarKind::Float => true, + crate::ScalarKind::Sint + | crate::ScalarKind::Uint + | crate::ScalarKind::Float + | crate::ScalarKind::AbstractInt + | crate::ScalarKind::AbstractFloat => true, crate::ScalarKind::Bool => false, } } diff --git a/naga/src/valid/expression.rs b/naga/src/valid/expression.rs index 1f57c55441..ba427dfda2 100644 --- a/naga/src/valid/expression.rs +++ b/naga/src/valid/expression.rs @@ -670,7 +670,11 @@ impl super::Validator { let good = match op { Bo::Add | Bo::Subtract => match *left_inner { Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind { - Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner, + Sk::Uint + | Sk::Sint + | Sk::Float + | Sk::AbstractInt + | Sk::AbstractFloat => left_inner == right_inner, Sk::Bool => false, }, Ti::Matrix { .. } => left_inner == right_inner, @@ -678,14 +682,24 @@ impl super::Validator { }, Bo::Divide | Bo::Modulo => match *left_inner { Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind { - Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner, + Sk::Uint + | Sk::Sint + | Sk::Float + | Sk::AbstractInt + | Sk::AbstractFloat => left_inner == right_inner, Sk::Bool => false, }, _ => false, }, Bo::Multiply => { let kind_allowed = match left_inner.scalar_kind() { - Some(Sk::Uint | Sk::Sint | Sk::Float) => true, + Some( + Sk::Uint + | Sk::Sint + | Sk::Float + | Sk::AbstractInt + | Sk::AbstractFloat, + ) => true, Some(Sk::Bool) | None => false, }; let types_match = match (left_inner, right_inner) { @@ -762,7 +776,11 @@ impl super::Validator { Bo::Less | Bo::LessEqual | Bo::Greater | Bo::GreaterEqual => { match *left_inner { Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind { - Sk::Uint | Sk::Sint | Sk::Float => left_inner == right_inner, + Sk::Uint + | Sk::Sint + | Sk::Float + | Sk::AbstractInt + | Sk::AbstractFloat => left_inner == right_inner, Sk::Bool => false, }, ref other => { @@ -784,8 +802,10 @@ impl super::Validator { }, Bo::And | Bo::InclusiveOr => match *left_inner { Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind { - Sk::Bool | Sk::Sint | Sk::Uint => left_inner == right_inner, - Sk::Float => false, + Sk::Bool | Sk::Sint | Sk::Uint | Sk::AbstractInt => { + left_inner == right_inner + } + Sk::Float | Sk::AbstractFloat => false, }, ref other => { log::error!("Op {:?} left type {:?}", op, other); @@ -794,8 +814,8 @@ impl super::Validator { }, Bo::ExclusiveOr => match *left_inner { Ti::Scalar(scalar) | Ti::Vector { scalar, .. } => match scalar.kind { - Sk::Sint | Sk::Uint => left_inner == right_inner, - Sk::Bool | Sk::Float => false, + Sk::Sint | Sk::Uint | Sk::AbstractInt => left_inner == right_inner, + Sk::Bool | Sk::Float | Sk::AbstractFloat => false, }, ref other => { log::error!("Op {:?} left type {:?}", op, other); @@ -823,8 +843,10 @@ impl super::Validator { } }; match base_scalar.kind { - Sk::Sint | Sk::Uint => base_size.is_ok() && base_size == shift_size, - Sk::Float | Sk::Bool => false, + Sk::Sint | Sk::Uint | Sk::AbstractInt => { + base_size.is_ok() && base_size == shift_size + } + Sk::Float | Sk::AbstractFloat | Sk::Bool => false, } } }; diff --git a/naga/src/valid/type.rs b/naga/src/valid/type.rs index 53462fe801..1e3e03fe19 100644 --- a/naga/src/valid/type.rs +++ b/naga/src/valid/type.rs @@ -143,6 +143,9 @@ pub enum WidthError { #[error("64-bit integers are not yet supported")] Unsupported64Bit, + + #[error("Abstract types may only appear in constant expressions")] + Abstract, } // Only makes sense if `flags.contains(HOST_SHAREABLE)` @@ -248,6 +251,9 @@ impl super::Validator { } scalar.width == 4 } + crate::ScalarKind::AbstractInt | crate::ScalarKind::AbstractFloat => { + return Err(WidthError::Abstract); + } }; if good { Ok(()) @@ -325,7 +331,10 @@ impl super::Validator { } Ti::Atomic(crate::Scalar { kind, width }) => { let good = match kind { - crate::ScalarKind::Bool | crate::ScalarKind::Float => false, + crate::ScalarKind::Bool + | crate::ScalarKind::Float + | crate::ScalarKind::AbstractInt + | crate::ScalarKind::AbstractFloat => false, crate::ScalarKind::Sint | crate::ScalarKind::Uint => width == 4, }; if !good { diff --git a/wgpu-core/src/validation.rs b/wgpu-core/src/validation.rs index 687f75e150..6eb52376b8 100644 --- a/wgpu-core/src/validation.rs +++ b/wgpu-core/src/validation.rs @@ -560,7 +560,9 @@ impl Resource { } naga::ScalarKind::Sint => wgt::TextureSampleType::Sint, naga::ScalarKind::Uint => wgt::TextureSampleType::Uint, - naga::ScalarKind::Bool => unreachable!(), + naga::ScalarKind::AbstractInt + | naga::ScalarKind::AbstractFloat + | naga::ScalarKind::Bool => unreachable!(), }, view_dimension, multisampled: multi,