diff --git a/src/arena.rs b/src/arena.rs index 99d977b2ff..1e7b659371 100644 --- a/src/arena.rs +++ b/src/arena.rs @@ -15,6 +15,15 @@ pub struct BadHandle { pub index: usize, } +impl BadHandle { + fn new(handle: Handle) -> Self { + Self { + kind: std::any::type_name::(), + index: handle.index(), + } + } +} + /// A strongly typed reference to an arena item. /// /// A `Handle` value can be used as an index into an [`Arena`] or [`UniqueArena`]. @@ -123,6 +132,35 @@ pub struct Range { marker: PhantomData, } +impl Range { + pub(crate) const fn erase_type(self) -> Range<()> { + let Self { inner, marker: _ } = self; + Range { + inner, + marker: PhantomData, + } + } +} + +// NOTE: Keep this diagnostic in sync with that of [`BadHandle`]. +#[derive(Clone, Debug, thiserror::Error)] +#[error("Handle range {range:?} of {kind} is either not present, or inaccessible yet")] +pub struct BadRangeError { + // This error is used for many `Handle` types, but there's no point in making this generic, so + // we just flatten them all to `Handle<()>` here. + kind: &'static str, + range: Range<()>, +} + +impl BadRangeError { + pub fn new(range: Range) -> Self { + Self { + kind: std::any::type_name::(), + range: range.erase_type(), + } + } +} + impl Clone for Range { fn clone(&self) -> Self { Range { @@ -282,10 +320,9 @@ impl Arena { } pub fn try_get(&self, handle: Handle) -> Result<&T, BadHandle> { - self.data.get(handle.index()).ok_or_else(|| BadHandle { - kind: std::any::type_name::(), - index: handle.index(), - }) + self.data + .get(handle.index()) + .ok_or_else(|| BadHandle::new(handle)) } /// Get a mutable reference to an element in the arena. @@ -320,6 +357,31 @@ impl Arena { Span::default() } } + + /// Assert that `handle` is valid for this arena. + pub fn check_contains_handle(&self, handle: Handle) -> Result<(), BadHandle> { + if handle.index() < self.data.len() { + Ok(()) + } else { + Err(BadHandle::new(handle)) + } + } + + /// Assert that `range` is valid for this arena. + pub fn check_contains_range(&self, range: &Range) -> Result<(), BadRangeError> { + // Since `range.inner` is a `Range`, we only need to + // check that the start precedes the end, and that the end is + // in range. + if range.inner.start > range.inner.end + || self + .check_contains_handle(Handle::new(range.inner.end.try_into().unwrap())) + .is_err() + { + Err(BadRangeError::new(range.clone())) + } else { + Ok(()) + } + } } #[cfg(feature = "deserialize")] @@ -540,10 +602,18 @@ impl UniqueArena { /// Return this arena's value at `handle`, if that is a valid handle. pub fn get_handle(&self, handle: Handle) -> Result<&T, BadHandle> { - self.set.get_index(handle.index()).ok_or_else(|| BadHandle { - kind: std::any::type_name::(), - index: handle.index(), - }) + self.set + .get_index(handle.index()) + .ok_or_else(|| BadHandle::new(handle)) + } + + /// Assert that `handle` is valid for this arena. + pub fn check_contains_handle(&self, handle: Handle) -> Result<(), BadHandle> { + if handle.index() < self.set.len() { + Ok(()) + } else { + Err(BadHandle::new(handle)) + } } } diff --git a/src/back/hlsl/conv.rs b/src/back/hlsl/conv.rs index 3feb80bb68..525930e1cf 100644 --- a/src/back/hlsl/conv.rs +++ b/src/back/hlsl/conv.rs @@ -40,12 +40,12 @@ impl crate::TypeInner { } } - pub(super) fn try_size_hlsl( + pub(super) fn size_hlsl( &self, types: &crate::UniqueArena, constants: &crate::Arena, - ) -> Result { - Ok(match *self { + ) -> u32 { + match *self { Self::Matrix { columns, rows, @@ -58,17 +58,16 @@ impl crate::TypeInner { Self::Array { base, size, stride } => { let count = match size { crate::ArraySize::Constant(handle) => { - let constant = constants.try_get(handle)?; - constant.to_array_length().unwrap_or(1) + constants[handle].to_array_length().unwrap_or(1) } // A dynamically-sized array has to have at least one element crate::ArraySize::Dynamic => 1, }; - let last_el_size = types[base].inner.try_size_hlsl(types, constants)?; + let last_el_size = types[base].inner.size_hlsl(types, constants); ((count - 1) * stride) + last_el_size } - _ => self.try_size(constants)?, - }) + _ => self.size(constants), + } } /// Used to generate the name of the wrapped type constructor diff --git a/src/back/hlsl/writer.rs b/src/back/hlsl/writer.rs index e1325e5abf..bf841a121c 100644 --- a/src/back/hlsl/writer.rs +++ b/src/back/hlsl/writer.rs @@ -829,10 +829,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { } } let ty_inner = &module.types[member.ty].inner; - last_offset = member.offset - + ty_inner - .try_size_hlsl(&module.types, &module.constants) - .unwrap(); + last_offset = member.offset + ty_inner.size_hlsl(&module.types, &module.constants); // The indentation is only for readability write!(self.out, "{}", back::INDENT)?; diff --git a/src/proc/layouter.rs b/src/proc/layouter.rs index 0c3a00db15..db07f261a4 100644 --- a/src/proc/layouter.rs +++ b/src/proc/layouter.rs @@ -1,4 +1,4 @@ -use crate::arena::{Arena, BadHandle, Handle, UniqueArena}; +use crate::arena::{Arena, Handle, UniqueArena}; use std::{fmt::Display, num::NonZeroU32, ops}; /// A newtype struct where its only valid values are powers of 2 @@ -130,8 +130,6 @@ pub enum LayoutErrorInner { InvalidStructMemberType(u32, Handle), #[error("Type width must be a power of two")] NonPowerOfTwoWidth, - #[error("Array size is a bad handle")] - BadHandle(#[from] BadHandle), } #[derive(Clone, Copy, Debug, PartialEq, thiserror::Error)] @@ -175,10 +173,7 @@ impl Layouter { use crate::TypeInner as Ti; for (ty_handle, ty) in types.iter().skip(self.layouts.len()) { - let size = ty - .inner - .try_size(constants) - .map_err(|error| LayoutErrorInner::BadHandle(error).with(ty_handle))?; + let size = ty.inner.size(constants); let layout = match ty.inner { Ti::Scalar { width, .. } | Ti::Atomic { width, .. } => { let alignment = Alignment::new(width as u32) diff --git a/src/proc/mod.rs b/src/proc/mod.rs index a5731de896..c718c33b24 100644 --- a/src/proc/mod.rs +++ b/src/proc/mod.rs @@ -97,11 +97,9 @@ impl super::TypeInner { } } - pub fn try_size( - &self, - constants: &super::Arena, - ) -> Result { - Ok(match *self { + /// Get the size of this type. + pub fn size(&self, constants: &super::Arena) -> u32 { + match *self { Self::Scalar { kind: _, width } | Self::Atomic { kind: _, width } => width as u32, Self::Vector { size, @@ -122,8 +120,7 @@ impl super::TypeInner { } => { let count = match size { super::ArraySize::Constant(handle) => { - let constant = constants.try_get(handle)?; - constant.to_array_length().unwrap_or(1) + constants[handle].to_array_length().unwrap_or(1) } // A dynamically-sized array has to have at least one element super::ArraySize::Dynamic => 1, @@ -132,13 +129,7 @@ impl super::TypeInner { } Self::Struct { span, .. } => span, Self::Image { .. } | Self::Sampler { .. } | Self::BindingArray { .. } => 0, - }) - } - - /// Get the size of this type. Panics if the `constants` doesn't contain - /// a referenced handle. This may not happen in a properly validated IR module. - pub fn size(&self, constants: &super::Arena) -> u32 { - self.try_size(constants).unwrap() + } } /// Return the canonical form of `self`, or `None` if it's already in diff --git a/src/proc/typifier.rs b/src/proc/typifier.rs index 9df538cc2b..47ad05e06c 100644 --- a/src/proc/typifier.rs +++ b/src/proc/typifier.rs @@ -1,4 +1,4 @@ -use crate::arena::{Arena, BadHandle, Handle, UniqueArena}; +use crate::arena::{Arena, Handle, UniqueArena}; use thiserror::Error; @@ -162,8 +162,6 @@ impl crate::ConstantInner { #[derive(Clone, Debug, Error, PartialEq)] pub enum ResolveError { - #[error(transparent)] - BadHandle(#[from] BadHandle), #[error("Index {index} is out of bounds for expression {expr:?}")] OutOfBoundsIndex { expr: Handle, @@ -195,8 +193,6 @@ pub enum ResolveError { IncompatibleOperands(String), #[error("Function argument {0} doesn't exist")] FunctionArgumentNotFound(u32), - #[error("Expression {0:?} depends on expressions that follow")] - ExpressionForwardDependency(Handle), } pub struct ResolveContext<'a> { @@ -403,20 +399,15 @@ impl<'a> ResolveContext<'a> { } } } - crate::Expression::Constant(h) => { - let constant = self.constants.try_get(h)?; - match constant.inner { - crate::ConstantInner::Scalar { width, ref value } => { - TypeResolution::Value(Ti::Scalar { - kind: value.scalar_kind(), - width, - }) - } - crate::ConstantInner::Composite { ty, components: _ } => { - TypeResolution::Handle(ty) - } + crate::Expression::Constant(h) => match self.constants[h].inner { + crate::ConstantInner::Scalar { width, ref value } => { + TypeResolution::Value(Ti::Scalar { + kind: value.scalar_kind(), + width, + }) } - } + crate::ConstantInner::Composite { ty, components: _ } => TypeResolution::Handle(ty), + }, crate::Expression::Splat { size, value } => match *past(value)?.inner_with(types) { Ti::Scalar { kind, width } => { TypeResolution::Value(Ti::Vector { size, kind, width }) @@ -450,7 +441,7 @@ impl<'a> ResolveContext<'a> { TypeResolution::Handle(arg.ty) } crate::Expression::GlobalVariable(h) => { - let var = self.global_vars.try_get(h)?; + let var = &self.global_vars[h]; if var.space == crate::AddressSpace::Handle { TypeResolution::Handle(var.ty) } else { @@ -461,7 +452,7 @@ impl<'a> ResolveContext<'a> { } } crate::Expression::LocalVariable(h) => { - let var = self.local_vars.try_get(h)?; + let var = &self.local_vars[h]; TypeResolution::Value(Ti::Pointer { base: var.ty, space: crate::AddressSpace::Function, diff --git a/src/valid/analyzer.rs b/src/valid/analyzer.rs index ec8e9d96e5..eb6c1fc4a7 100644 --- a/src/valid/analyzer.rs +++ b/src/valid/analyzer.rs @@ -10,7 +10,7 @@ use super::{CallError, ExpressionError, FunctionError, ModuleInfo, ShaderStages, use crate::span::{AddSpan as _, WithSpan}; use crate::{ arena::{Arena, Handle}, - proc::{ResolveContext, ResolveError, TypeResolution}, + proc::{ResolveContext, TypeResolution}, }; use std::ops; @@ -706,12 +706,7 @@ impl FunctionInfo { }, }; - let ty = resolve_context.resolve(expression, |h| { - self.expressions - .get(h.index()) - .map(|ei| &ei.ty) - .ok_or(ResolveError::ExpressionForwardDependency(h)) - })?; + let ty = resolve_context.resolve(expression, |h| Ok(&self[h].ty))?; self.expressions[handle.index()] = ExpressionInfo { uniformity, ref_count: 0, diff --git a/src/valid/compose.rs b/src/valid/compose.rs index 6e5c499223..e77d538255 100644 --- a/src/valid/compose.rs +++ b/src/valid/compose.rs @@ -4,13 +4,11 @@ use crate::{ proc::TypeResolution, }; -use crate::arena::{BadHandle, Handle}; +use crate::arena::Handle; #[derive(Clone, Debug, thiserror::Error)] #[cfg_attr(test, derive(PartialEq))] pub enum ComposeError { - #[error(transparent)] - BadHandle(#[from] BadHandle), #[error("Composing of type {0:?} can't be done")] Type(Handle), #[error("Composing expects {expected} components but {given} were given")] @@ -28,8 +26,7 @@ pub fn validate_compose( ) -> Result<(), ComposeError> { use crate::TypeInner as Ti; - let self_ty = type_arena.get_handle(self_ty_handle)?; - match self_ty.inner { + match type_arena[self_ty_handle].inner { // vectors are composed from scalars or other vectors Ti::Vector { size, kind, width } => { let mut total = 0; diff --git a/src/valid/expression.rs b/src/valid/expression.rs index a3afc0535a..14f52fb93c 100644 --- a/src/valid/expression.rs +++ b/src/valid/expression.rs @@ -1,3 +1,6 @@ +#[cfg(feature = "validate")] +use std::ops::Index; + #[cfg(feature = "validate")] use super::{ compose::validate_compose, validate_atomic_compare_exchange_struct, FunctionInfo, ShaderStages, @@ -7,7 +10,7 @@ use super::{ use crate::arena::UniqueArena; use crate::{ - arena::{BadHandle, Handle}, + arena::Handle, proc::{IndexableLengthError, ResolveError}, }; @@ -18,10 +21,6 @@ pub enum ExpressionError { DoesntExist, #[error("Used by a statement before it was introduced into the scope by any of the dominating blocks")] NotInScope, - #[error("Depends on {0:?}, which has not been processed yet")] - ForwardDependency(Handle), - #[error(transparent)] - BadDependency(#[from] BadHandle), #[error("Base type {0:?} is not compatible with this expression")] InvalidBaseType(Handle), #[error("Accessing with index {0:?} can't be done")] @@ -132,15 +131,19 @@ struct ExpressionTypeResolver<'a> { } #[cfg(feature = "validate")] -impl<'a> ExpressionTypeResolver<'a> { - fn resolve( - &self, - handle: Handle, - ) -> Result<&'a crate::TypeInner, ExpressionError> { +impl<'a> Index> for ExpressionTypeResolver<'a> { + type Output = crate::TypeInner; + + #[allow(clippy::panic)] + fn index(&self, handle: Handle) -> &Self::Output { if handle < self.root { - Ok(self.info[handle].ty.inner_with(self.types)) + self.info[handle].ty.inner_with(self.types) } else { - Err(ExpressionError::ForwardDependency(handle)) + // `Validator::validate_module_handles` should have caught this. + panic!( + "Depends on {:?}, which has not been processed yet", + self.root + ) } } } @@ -166,7 +169,7 @@ impl super::Validator { let stages = match *expression { E::Access { base, index } => { - let base_type = resolver.resolve(base)?; + let base_type = &resolver[base]; // See the documentation for `Expression::Access`. let dynamic_indexing_restricted = match *base_type { Ti::Vector { .. } => false, @@ -179,7 +182,7 @@ impl super::Validator { return Err(ExpressionError::InvalidBaseType(base)); } }; - match *resolver.resolve(index)? { + match resolver[index] { //TODO: only allow one of these Ti::Scalar { kind: Sk::Sint | Sk::Uint, @@ -257,7 +260,7 @@ impl super::Validator { Ok(limit) } - let limit = resolve_index_limit(module, base, resolver.resolve(base)?, true)?; + let limit = resolve_index_limit(module, base, &resolver[base], true)?; if index >= limit { return Err(ExpressionError::IndexOutOfBounds( base, @@ -266,11 +269,8 @@ impl super::Validator { } ShaderStages::all() } - E::Constant(handle) => { - let _ = module.constants.try_get(handle)?; - ShaderStages::all() - } - E::Splat { size: _, value } => match *resolver.resolve(value)? { + E::Constant(_handle) => ShaderStages::all(), + E::Splat { size: _, value } => match resolver[value] { Ti::Scalar { .. } => ShaderStages::all(), ref other => { log::error!("Splat scalar type {:?}", other); @@ -282,7 +282,7 @@ impl super::Validator { vector, pattern, } => { - let vec_size = match *resolver.resolve(vector)? { + let vec_size = match resolver[vector] { Ti::Vector { size: vec_size, .. } => vec_size, ref other => { log::error!("Swizzle vector type {:?}", other); @@ -297,11 +297,6 @@ impl super::Validator { ShaderStages::all() } E::Compose { ref components, ty } => { - for &handle in components { - if handle >= root { - return Err(ExpressionError::ForwardDependency(handle)); - } - } validate_compose( ty, &module.constants, @@ -316,16 +311,10 @@ impl super::Validator { } ShaderStages::all() } - E::GlobalVariable(handle) => { - let _ = module.global_variables.try_get(handle)?; - ShaderStages::all() - } - E::LocalVariable(handle) => { - let _ = function.local_variables.try_get(handle)?; - ShaderStages::all() - } + E::GlobalVariable(_handle) => ShaderStages::all(), + E::LocalVariable(_handle) => ShaderStages::all(), E::Load { pointer } => { - match *resolver.resolve(pointer)? { + match resolver[pointer] { Ti::Pointer { base, .. } if self.types[base.index()] .flags @@ -368,7 +357,7 @@ impl super::Validator { return Err(ExpressionError::InvalidImageArrayIndex); } if let Some(expr) = array_index { - match *resolver.resolve(expr)? { + match resolver[expr] { Ti::Scalar { kind: Sk::Sint, width: _, @@ -408,7 +397,7 @@ impl super::Validator { crate::ImageDimension::D2 => 2, crate::ImageDimension::D3 | crate::ImageDimension::Cube => 3, }; - match *resolver.resolve(coordinate)? { + match resolver[coordinate] { Ti::Scalar { kind: Sk::Float, .. } if num_components == 1 => {} @@ -446,7 +435,7 @@ impl super::Validator { // check depth reference type if let Some(expr) = depth_ref { - match *resolver.resolve(expr)? { + match resolver[expr] { Ti::Scalar { kind: Sk::Float, .. } => {} @@ -483,7 +472,7 @@ impl super::Validator { crate::SampleLevel::Auto => ShaderStages::FRAGMENT, crate::SampleLevel::Zero => ShaderStages::all(), crate::SampleLevel::Exact(expr) => { - match *resolver.resolve(expr)? { + match resolver[expr] { Ti::Scalar { kind: Sk::Float, .. } => {} @@ -492,7 +481,7 @@ impl super::Validator { ShaderStages::all() } crate::SampleLevel::Bias(expr) => { - match *resolver.resolve(expr)? { + match resolver[expr] { Ti::Scalar { kind: Sk::Float, .. } => {} @@ -501,7 +490,7 @@ impl super::Validator { ShaderStages::all() } crate::SampleLevel::Gradient { x, y } => { - match *resolver.resolve(x)? { + match resolver[x] { Ti::Scalar { kind: Sk::Float, .. } if num_components == 1 => {} @@ -514,7 +503,7 @@ impl super::Validator { return Err(ExpressionError::InvalidSampleLevelGradientType(dim, x)) } } - match *resolver.resolve(y)? { + match resolver[y] { Ti::Scalar { kind: Sk::Float, .. } if num_components == 1 => {} @@ -545,7 +534,7 @@ impl super::Validator { arrayed, dim, } => { - match resolver.resolve(coordinate)?.image_storage_coordinates() { + match resolver[coordinate].image_storage_coordinates() { Some(coord_dim) if coord_dim == dim => {} _ => { return Err(ExpressionError::InvalidImageCoordinateType( @@ -557,7 +546,7 @@ impl super::Validator { return Err(ExpressionError::InvalidImageArrayIndex); } if let Some(expr) = array_index { - match *resolver.resolve(expr)? { + match resolver[expr] { Ti::Scalar { kind: Sk::Sint, width: _, @@ -569,7 +558,7 @@ impl super::Validator { match (sample, class.is_multisampled()) { (None, false) => {} (Some(sample), true) => { - if resolver.resolve(sample)?.scalar_kind() != Some(Sk::Sint) { + if resolver[sample].scalar_kind() != Some(Sk::Sint) { return Err(ExpressionError::InvalidImageOtherIndexType( sample, )); @@ -583,7 +572,7 @@ impl super::Validator { match (level, class.is_mipmapped()) { (None, false) => {} (Some(level), true) => { - if resolver.resolve(level)?.scalar_kind() != Some(Sk::Sint) { + if resolver[level].scalar_kind() != Some(Sk::Sint) { return Err(ExpressionError::InvalidImageOtherIndexType(level)); } } @@ -617,7 +606,7 @@ impl super::Validator { } E::Unary { op, expr } => { use crate::UnaryOperator as Uo; - let inner = resolver.resolve(expr)?; + let inner = &resolver[expr]; match (op, inner.scalar_kind()) { (_, Some(Sk::Sint | Sk::Bool)) //TODO: restrict Negate for bools? @@ -632,8 +621,8 @@ impl super::Validator { } E::Binary { op, left, right } => { use crate::BinaryOperator as Bo; - let left_inner = resolver.resolve(left)?; - let right_inner = resolver.resolve(right)?; + let left_inner = &resolver[left]; + let right_inner = &resolver[right]; let good = match op { Bo::Add | Bo::Subtract => match *left_inner { Ti::Scalar { kind, .. } | Ti::Vector { kind, .. } => match kind { @@ -814,9 +803,9 @@ impl super::Validator { accept, reject, } => { - let accept_inner = resolver.resolve(accept)?; - let reject_inner = resolver.resolve(reject)?; - let condition_good = match *resolver.resolve(condition)? { + let accept_inner = &resolver[accept]; + let reject_inner = &resolver[reject]; + let condition_good = match resolver[condition] { Ti::Scalar { kind: Sk::Bool, width: _, @@ -846,7 +835,7 @@ impl super::Validator { ShaderStages::all() } E::Derivative { axis: _, expr } => { - match *resolver.resolve(expr)? { + match resolver[expr] { Ti::Scalar { kind: Sk::Float, .. } @@ -859,7 +848,7 @@ impl super::Validator { } E::Relational { fun, argument } => { use crate::RelationalFunction as Rf; - let argument_inner = resolver.resolve(argument)?; + let argument_inner = &resolver[argument]; match fun { Rf::All | Rf::Any => match *argument_inner { Ti::Vector { kind: Sk::Bool, .. } => {} @@ -892,10 +881,11 @@ impl super::Validator { } => { use crate::MathFunction as Mf; - let arg_ty = resolver.resolve(arg)?; - let arg1_ty = arg1.map(|expr| resolver.resolve(expr)).transpose()?; - let arg2_ty = arg2.map(|expr| resolver.resolve(expr)).transpose()?; - let arg3_ty = arg3.map(|expr| resolver.resolve(expr)).transpose()?; + let resolve = |arg| &resolver[arg]; + let arg_ty = resolve(arg); + let arg1_ty = arg1.map(resolve); + let arg2_ty = arg2.map(resolve); + let arg3_ty = arg3.map(resolve); match fun { Mf::Abs => { if arg1_ty.is_some() | arg2_ty.is_some() | arg3_ty.is_some() { @@ -1379,7 +1369,7 @@ impl super::Validator { kind, convert, } => { - let base_width = match *resolver.resolve(expr)? { + let base_width = match resolver[expr] { crate::TypeInner::Scalar { width, .. } | crate::TypeInner::Vector { width, .. } | crate::TypeInner::Matrix { width, .. } => width, @@ -1416,9 +1406,9 @@ impl super::Validator { } ShaderStages::all() } - E::ArrayLength(expr) => match *resolver.resolve(expr)? { + E::ArrayLength(expr) => match resolver[expr] { Ti::Pointer { base, .. } => { - let base_ty = resolver.types.get_handle(base)?; + let base_ty = &resolver.types[base]; if let Ti::Array { size: crate::ArraySize::Dynamic, .. diff --git a/src/valid/function.rs b/src/valid/function.rs index 0f0a7b89f5..3c555491c3 100644 --- a/src/valid/function.rs +++ b/src/valid/function.rs @@ -1,6 +1,6 @@ +use crate::arena::Handle; #[cfg(feature = "validate")] use crate::arena::{Arena, UniqueArena}; -use crate::arena::{BadHandle, Handle}; #[cfg(feature = "validate")] use super::validate_atomic_compare_exchange_struct; @@ -19,8 +19,6 @@ use bit_set::BitSet; #[derive(Clone, Debug, thiserror::Error)] #[cfg_attr(test, derive(PartialEq))] pub enum CallError { - #[error(transparent)] - BadHandle(#[from] BadHandle), #[error("The callee is declared after the caller")] ForwardDeclaredFunction, #[error("Argument {index} expression is invalid")] @@ -69,8 +67,6 @@ pub enum LocalVariableError { #[derive(Clone, Debug, thiserror::Error)] #[cfg_attr(test, derive(PartialEq))] pub enum FunctionError { - #[error(transparent)] - BadHandle(#[from] BadHandle), #[error("Expression {handle:?} is invalid")] Expression { handle: Handle, @@ -203,11 +199,8 @@ impl<'a> BlockContext<'a> { BlockContext { abilities, ..*self } } - fn get_expression( - &self, - handle: Handle, - ) -> Result<&'a crate::Expression, FunctionError> { - Ok(self.expressions.try_get(handle)?) + fn get_expression(&self, handle: Handle) -> &'a crate::Expression { + &self.expressions[handle] } fn resolve_type_impl( @@ -257,11 +250,7 @@ impl super::Validator { result: Option>, context: &BlockContext, ) -> Result> { - let fun = context - .functions - .try_get(function) - .map_err(CallError::BadHandle) - .map_err(WithSpan::new)?; + let fun = &context.functions[function]; if fun.arguments.len() != arguments.len() { return Err(CallError::ArgumentCount { required: fun.arguments.len(), @@ -689,14 +678,14 @@ impl super::Validator { } => { //Note: this code uses a lot of `FunctionError::InvalidImageStore`, // and could probably be refactored. - let var = match *context.get_expression(image).map_err(|e| e.with_span())? { + let var = match *context.get_expression(image) { crate::Expression::GlobalVariable(var_handle) => { &context.global_vars[var_handle] } // We're looking at a binding index situation, so punch through the index and look at the global behind it. crate::Expression::Access { base, .. } | crate::Expression::AccessIndex { base, .. } => { - match *context.get_expression(base).map_err(|e| e.with_span())? { + match *context.get_expression(base) { crate::Expression::GlobalVariable(var_handle) => { &context.global_vars[var_handle] } @@ -899,10 +888,7 @@ impl super::Validator { #[cfg(feature = "validate")] for (index, argument) in fun.arguments.iter().enumerate() { - let ty = module.types.get_handle(argument.ty).map_err(|err| { - FunctionError::from(err).with_span_handle(argument.ty, &module.types) - })?; - match ty.inner.pointer_space() { + match module.types[argument.ty].inner.pointer_space() { Some( crate::AddressSpace::Private | crate::AddressSpace::Function diff --git a/src/valid/handles.rs b/src/valid/handles.rs new file mode 100644 index 0000000000..5b3375a873 --- /dev/null +++ b/src/valid/handles.rs @@ -0,0 +1,616 @@ +//! Implementation of [`super::Validator::validate_module_handles`]. + +use crate::{ + arena::{BadHandle, BadRangeError}, + Handle, +}; + +#[cfg(feature = "validate")] +use crate::{Arena, UniqueArena}; + +#[cfg(feature = "validate")] +use super::{TypeError, ValidationError}; + +#[cfg(feature = "validate")] +use std::{convert::TryInto, hash::Hash, num::NonZeroU32}; + +#[cfg(feature = "validate")] +impl super::Validator { + /// Validates that all handles within `module` are: + /// + /// * Valid, in the sense that they contain indices within each arena structure inside the + /// [`crate::Module`] type. + /// * No arena contents contain any items that have forward dependencies; that is, the value + /// associated with a handle only may contain references to handles in the same arena that + /// were constructed before it. + /// + /// By validating the above conditions, we free up subsequent logic to assume that handle + /// accesses are infallible. + /// + /// # Errors + /// + /// Errors returned by this method are intentionally sparse, for simplicity of implementation. + /// It is expected that only buggy frontends or fuzzers should ever emit IR that fails this + /// validation pass. + pub(super) fn validate_module_handles(module: &crate::Module) -> Result<(), ValidationError> { + let &crate::Module { + ref constants, + ref entry_points, + ref functions, + ref global_variables, + ref types, + } = module; + + // NOTE: Types being first is important. All other forms of validation depend on this. + for (this_handle, ty) in types.iter() { + let &crate::Type { + ref name, + ref inner, + } = ty; + + let validate_array_size = |size| { + match size { + crate::ArraySize::Constant(constant) => { + let &crate::Constant { + name: _, + specialization: _, + ref inner, + } = constants.try_get(constant)?; + if !matches!(inner, &crate::ConstantInner::Scalar { .. }) { + return Err(ValidationError::Type { + handle: this_handle, + name: name.clone().unwrap_or_default(), + source: TypeError::InvalidArraySizeConstant(constant), + }); + } + } + crate::ArraySize::Dynamic => (), + }; + Ok(this_handle) + }; + + match *inner { + crate::TypeInner::Scalar { .. } + | crate::TypeInner::Vector { .. } + | crate::TypeInner::Matrix { .. } + | crate::TypeInner::ValuePointer { .. } + | crate::TypeInner::Atomic { .. } + | crate::TypeInner::Image { .. } + | crate::TypeInner::Sampler { .. } => (), + crate::TypeInner::Pointer { base, space: _ } => { + this_handle.check_dep(base)?; + } + crate::TypeInner::Array { + base, + size, + stride: _, + } + | crate::TypeInner::BindingArray { base, size } => { + this_handle.check_dep(base)?; + validate_array_size(size)?; + } + crate::TypeInner::Struct { + ref members, + span: _, + } => { + this_handle.check_dep_iter(members.iter().map(|m| m.ty))?; + } + } + } + + let validate_type = |handle| Self::validate_type_handle(handle, types); + + for (this_handle, constant) in constants.iter() { + let &crate::Constant { + name: _, + specialization: _, + ref inner, + } = constant; + match *inner { + crate::ConstantInner::Scalar { .. } => (), + crate::ConstantInner::Composite { ty, ref components } => { + validate_type(ty)?; + this_handle.check_dep_iter(components.iter().copied())?; + } + } + } + + let validate_constant = |handle| Self::validate_constant_handle(handle, constants); + + for (_handle, global_variable) in global_variables.iter() { + let &crate::GlobalVariable { + name: _, + space: _, + binding: _, + ty, + init, + } = global_variable; + validate_type(ty)?; + if let Some(init_expr) = init { + validate_constant(init_expr)?; + } + } + + let validate_function = |function: &_| -> Result<_, InvalidHandleError> { + let &crate::Function { + name: _, + ref arguments, + ref result, + ref local_variables, + ref expressions, + ref named_expressions, + ref body, + } = function; + + for arg in arguments.iter() { + let &crate::FunctionArgument { + name: _, + ty, + binding: _, + } = arg; + validate_type(ty)?; + } + + if let &Some(crate::FunctionResult { ty, binding: _ }) = result { + validate_type(ty)?; + } + + for (_handle, local_variable) in local_variables.iter() { + let &crate::LocalVariable { name: _, ty, init } = local_variable; + validate_type(ty)?; + if let Some(init_constant) = init { + validate_constant(init_constant)?; + } + } + + for handle in named_expressions.keys().copied() { + Self::validate_expression_handle(handle, expressions)?; + } + + for handle_and_expr in expressions.iter() { + Self::validate_expression_handles( + handle_and_expr, + constants, + types, + local_variables, + global_variables, + functions, + )?; + } + + Self::validate_block_handles(body, expressions, functions)?; + + Ok(()) + }; + + for entry_point in entry_points.iter() { + validate_function(&entry_point.function)?; + } + + for (_function_handle, function) in functions.iter() { + validate_function(function)?; + } + + Ok(()) + } + + fn validate_type_handle( + handle: Handle, + types: &UniqueArena, + ) -> Result<(), InvalidHandleError> { + handle.check_valid_for_uniq(types).map(|_| ()) + } + + fn validate_constant_handle( + handle: Handle, + constants: &Arena, + ) -> Result<(), InvalidHandleError> { + handle.check_valid_for(constants).map(|_| ()) + } + + fn validate_expression_handle( + handle: Handle, + expressions: &Arena, + ) -> Result<(), InvalidHandleError> { + handle.check_valid_for(expressions).map(|_| ()) + } + + fn validate_function_handle( + handle: Handle, + functions: &Arena, + ) -> Result<(), InvalidHandleError> { + handle.check_valid_for(functions).map(|_| ()) + } + + fn validate_expression_handles( + (handle, expression): (Handle, &crate::Expression), + constants: &Arena, + types: &UniqueArena, + local_variables: &Arena, + global_variables: &Arena, + functions: &Arena, + ) -> Result<(), InvalidHandleError> { + let validate_constant = |handle| Self::validate_constant_handle(handle, constants); + let validate_type = |handle| Self::validate_type_handle(handle, types); + + match *expression { + crate::Expression::Access { base, index } => { + handle.check_dep(base)?.check_dep(index)?; + } + crate::Expression::AccessIndex { base, .. } => { + handle.check_dep(base)?; + } + crate::Expression::Constant(constant) => { + validate_constant(constant)?; + } + crate::Expression::Splat { value, .. } => { + handle.check_dep(value)?; + } + crate::Expression::Swizzle { vector, .. } => { + handle.check_dep(vector)?; + } + crate::Expression::Compose { ty, ref components } => { + validate_type(ty)?; + handle.check_dep_iter(components.iter().copied())?; + } + crate::Expression::FunctionArgument(_arg_idx) => (), + crate::Expression::GlobalVariable(global_variable) => { + global_variable.check_valid_for(global_variables)?; + } + crate::Expression::LocalVariable(local_variable) => { + local_variable.check_valid_for(local_variables)?; + } + crate::Expression::Load { pointer } => { + handle.check_dep(pointer)?; + } + crate::Expression::ImageSample { + image, + sampler, + gather: _, + coordinate, + array_index, + offset, + level, + depth_ref, + } => { + if let Some(offset) = offset { + validate_constant(offset)?; + } + + handle + .check_dep(image)? + .check_dep(sampler)? + .check_dep(coordinate)? + .check_dep_opt(array_index)?; + + match level { + crate::SampleLevel::Auto | crate::SampleLevel::Zero => (), + crate::SampleLevel::Exact(expr) => { + handle.check_dep(expr)?; + } + crate::SampleLevel::Bias(expr) => { + handle.check_dep(expr)?; + } + crate::SampleLevel::Gradient { x, y } => { + handle.check_dep(x)?.check_dep(y)?; + } + }; + + handle.check_dep_opt(depth_ref)?; + } + crate::Expression::ImageLoad { + image, + coordinate, + array_index, + sample, + level, + } => { + handle + .check_dep(image)? + .check_dep(coordinate)? + .check_dep_opt(array_index)? + .check_dep_opt(sample)? + .check_dep_opt(level)?; + } + crate::Expression::ImageQuery { image, query } => { + handle.check_dep(image)?; + match query { + crate::ImageQuery::Size { level } => { + handle.check_dep_opt(level)?; + } + crate::ImageQuery::NumLevels + | crate::ImageQuery::NumLayers + | crate::ImageQuery::NumSamples => (), + }; + } + crate::Expression::Unary { + op: _, + expr: operand, + } => { + handle.check_dep(operand)?; + } + crate::Expression::Binary { op: _, left, right } => { + handle.check_dep(left)?.check_dep(right)?; + } + crate::Expression::Select { + condition, + accept, + reject, + } => { + handle + .check_dep(condition)? + .check_dep(accept)? + .check_dep(reject)?; + } + crate::Expression::Derivative { + axis: _, + expr: argument, + } => { + handle.check_dep(argument)?; + } + crate::Expression::Relational { fun: _, argument } => { + handle.check_dep(argument)?; + } + crate::Expression::Math { + fun: _, + arg, + arg1, + arg2, + arg3, + } => { + handle + .check_dep(arg)? + .check_dep_opt(arg1)? + .check_dep_opt(arg2)? + .check_dep_opt(arg3)?; + } + crate::Expression::As { + expr: input, + kind: _, + convert: _, + } => { + handle.check_dep(input)?; + } + crate::Expression::CallResult(function) => { + Self::validate_function_handle(function, functions)?; + } + crate::Expression::AtomicResult { .. } => (), + crate::Expression::ArrayLength(array) => { + handle.check_dep(array)?; + } + } + Ok(()) + } + + fn validate_block_handles( + block: &crate::Block, + expressions: &Arena, + functions: &Arena, + ) -> Result<(), InvalidHandleError> { + let validate_block = |block| Self::validate_block_handles(block, expressions, functions); + let validate_expr = |handle| Self::validate_expression_handle(handle, expressions); + let validate_expr_opt = |handle_opt| { + if let Some(handle) = handle_opt { + validate_expr(handle)?; + } + Ok(()) + }; + + block.iter().try_for_each(|stmt| match *stmt { + crate::Statement::Emit(ref expr_range) => { + expr_range.check_valid_for(expressions)?; + Ok(()) + } + crate::Statement::Block(ref block) => { + validate_block(block)?; + Ok(()) + } + crate::Statement::If { + condition, + ref accept, + ref reject, + } => { + validate_expr(condition)?; + validate_block(accept)?; + validate_block(reject)?; + Ok(()) + } + crate::Statement::Switch { + selector, + ref cases, + } => { + validate_expr(selector)?; + for &crate::SwitchCase { + value: _, + ref body, + fall_through: _, + } in cases + { + validate_block(body)?; + } + Ok(()) + } + crate::Statement::Loop { + ref body, + ref continuing, + break_if, + } => { + validate_block(body)?; + validate_block(continuing)?; + validate_expr_opt(break_if)?; + Ok(()) + } + crate::Statement::Return { value } => validate_expr_opt(value), + crate::Statement::Store { pointer, value } => { + validate_expr(pointer)?; + validate_expr(value)?; + Ok(()) + } + crate::Statement::ImageStore { + image, + coordinate, + array_index, + value, + } => { + validate_expr(image)?; + validate_expr(coordinate)?; + validate_expr_opt(array_index)?; + validate_expr(value)?; + Ok(()) + } + crate::Statement::Atomic { + pointer, + fun, + value, + result, + } => { + validate_expr(pointer)?; + match fun { + crate::AtomicFunction::Add + | crate::AtomicFunction::Subtract + | crate::AtomicFunction::And + | crate::AtomicFunction::ExclusiveOr + | crate::AtomicFunction::InclusiveOr + | crate::AtomicFunction::Min + | crate::AtomicFunction::Max => (), + crate::AtomicFunction::Exchange { compare } => validate_expr_opt(compare)?, + }; + validate_expr(value)?; + validate_expr(result)?; + Ok(()) + } + crate::Statement::Call { + function, + ref arguments, + result, + } => { + Self::validate_function_handle(function, functions)?; + for arg in arguments.iter().copied() { + validate_expr(arg)?; + } + validate_expr_opt(result)?; + Ok(()) + } + crate::Statement::Break + | crate::Statement::Continue + | crate::Statement::Kill + | crate::Statement::Barrier(_) => Ok(()), + }) + } +} + +#[cfg(feature = "validate")] +impl From for ValidationError { + fn from(source: BadHandle) -> Self { + Self::InvalidHandle(source.into()) + } +} + +#[cfg(feature = "validate")] +impl From for ValidationError { + fn from(source: FwdDepError) -> Self { + Self::InvalidHandle(source.into()) + } +} + +#[cfg(feature = "validate")] +impl From for ValidationError { + fn from(source: BadRangeError) -> Self { + Self::InvalidHandle(source.into()) + } +} + +#[derive(Clone, Debug, thiserror::Error)] +pub enum InvalidHandleError { + #[error(transparent)] + BadHandle(#[from] BadHandle), + #[error(transparent)] + ForwardDependency(#[from] FwdDepError), + #[error(transparent)] + BadRange(#[from] BadRangeError), +} + +#[derive(Clone, Debug, thiserror::Error)] +#[error( + "{subject:?} of kind depends on {depends_on:?} of kind {depends_on_kind}, which has not been \ + processed yet" +)] +pub struct FwdDepError { + // This error is used for many `Handle` types, but there's no point in making this generic, so + // we just flatten them all to `Handle<()>` here. + subject: Handle<()>, + subject_kind: &'static str, + depends_on: Handle<()>, + depends_on_kind: &'static str, +} + +#[cfg(feature = "validate")] +impl Handle { + /// Check that `self` is valid within `arena` using [`Arena::check_contains_handle`]. + pub(self) fn check_valid_for(self, arena: &Arena) -> Result<(), InvalidHandleError> { + arena.check_contains_handle(self)?; + Ok(()) + } + + /// Check that `self` is valid within `arena` using [`UniqueArena::check_contains_handle`]. + pub(self) fn check_valid_for_uniq( + self, + arena: &UniqueArena, + ) -> Result<(), InvalidHandleError> + where + T: Eq + Hash, + { + arena.check_contains_handle(self)?; + Ok(()) + } + + /// Check that `depends_on` was constructed before `self` by comparing handle indices. + /// + /// If `self` is a valid handle (i.e., it has been validated using [`Self::check_valid_for`]) + /// and this function returns [`Ok`], then it may be assumed that `depends_on` is also valid. + /// In [`naga`](crate)'s current arena-based implementation, this is useful for validating + /// recursive definitions of arena-based values in linear time. + /// + /// # Errors + /// + /// If `depends_on`'s handle is from the same [`Arena`] as `self'`s, but not constructed earlier + /// than `self`'s, this function returns an error. + pub(self) fn check_dep(self, depends_on: Self) -> Result { + if depends_on < self { + Ok(self) + } else { + let erase_handle_type = |handle: Handle<_>| { + Handle::new(NonZeroU32::new(handle.index().try_into().unwrap()).unwrap()) + }; + Err(FwdDepError { + subject: erase_handle_type(self), + subject_kind: std::any::type_name::(), + depends_on: erase_handle_type(depends_on), + depends_on_kind: std::any::type_name::(), + }) + } + } + + /// Like [`Self::check_dep`], except for [`Option`]al handle values. + pub(self) fn check_dep_opt(self, depends_on: Option) -> Result { + self.check_dep_iter(depends_on.into_iter()) + } + + /// Like [`Self::check_dep`], except for [`Iterator`]s over handle values. + pub(self) fn check_dep_iter( + self, + depends_on: impl Iterator, + ) -> Result { + for handle in depends_on { + self.check_dep(handle)?; + } + Ok(self) + } +} + +#[cfg(feature = "validate")] +impl crate::arena::Range { + pub(self) fn check_valid_for(&self, arena: &Arena) -> Result<(), BadRangeError> { + arena.check_contains_range(self) + } +} diff --git a/src/valid/interface.rs b/src/valid/interface.rs index 85610b068e..289a068f75 100644 --- a/src/valid/interface.rs +++ b/src/valid/interface.rs @@ -2,7 +2,7 @@ use super::{ analyzer::{FunctionInfo, GlobalUse}, Capabilities, Disalignment, FunctionError, ModuleInfo, }; -use crate::arena::{BadHandle, Handle, UniqueArena}; +use crate::arena::{Handle, UniqueArena}; use crate::span::{AddSpan as _, MapErrWithSpan as _, SpanProvider as _, WithSpan}; use bit_set::BitSet; @@ -12,8 +12,6 @@ const MAX_WORKGROUP_SIZE: u32 = 0x4000; #[derive(Clone, Debug, thiserror::Error)] pub enum GlobalVariableError { - #[error(transparent)] - BadHandle(#[from] BadHandle), #[error("Usage isn't compatible with address space {0:?}")] InvalidUsage(crate::AddressSpace), #[error("Type isn't compatible with address space {0:?}")] @@ -380,10 +378,7 @@ impl super::Validator { use super::TypeFlags; log::debug!("var {:?}", var); - let type_info = self.types.get(var.ty.index()).ok_or_else(|| BadHandle { - kind: "type", - index: var.ty.index(), - })?; + let type_info = &self.types[var.ty.index()]; let (required_type_flags, is_resource) = match var.space { crate::AddressSpace::Function => { diff --git a/src/valid/mod.rs b/src/valid/mod.rs index 255d6f428e..4e62a2ca78 100644 --- a/src/valid/mod.rs +++ b/src/valid/mod.rs @@ -6,6 +6,7 @@ mod analyzer; mod compose; mod expression; mod function; +mod handles; mod interface; mod r#type; @@ -13,7 +14,7 @@ mod r#type; use crate::arena::{Arena, UniqueArena}; use crate::{ - arena::{BadHandle, Handle}, + arena::Handle, proc::{LayoutError, Layouter}, FastHashSet, }; @@ -31,6 +32,8 @@ pub use function::{CallError, FunctionError, LocalVariableError}; pub use interface::{EntryPointError, GlobalVariableError, VaryingError}; pub use r#type::{Disalignment, TypeError, TypeFlags}; +use self::handles::InvalidHandleError; + bitflags::bitflags! { /// Validation flags. /// @@ -146,8 +149,6 @@ pub struct Validator { #[derive(Clone, Debug, thiserror::Error)] pub enum ConstantError { - #[error(transparent)] - BadHandle(#[from] BadHandle), #[error("The type doesn't match the constant")] InvalidType, #[error("The component handle {0:?} can not be resolved")] @@ -160,6 +161,8 @@ pub enum ConstantError { #[derive(Clone, Debug, thiserror::Error)] pub enum ValidationError { + #[error(transparent)] + InvalidHandle(#[from] InvalidHandleError), #[error(transparent)] Layouter(#[from] LayoutError), #[error("Type {handle:?} '{name}' is invalid")] @@ -283,7 +286,7 @@ impl Validator { } } crate::ConstantInner::Composite { ty, ref components } => { - match types.get_handle(ty)?.inner { + match types[ty].inner { crate::TypeInner::Array { size: crate::ArraySize::Constant(size_handle), .. @@ -316,6 +319,9 @@ impl Validator { self.reset(); self.reset_types(module.types.len()); + #[cfg(feature = "validate")] + Self::validate_module_handles(module).map_err(|e| e.with_span())?; + self.layouter .update(&module.types, &module.constants) .map_err(|e| { diff --git a/src/valid/type.rs b/src/valid/type.rs index f103017dd9..172f110724 100644 --- a/src/valid/type.rs +++ b/src/valid/type.rs @@ -1,6 +1,6 @@ use super::Capabilities; use crate::{ - arena::{Arena, BadHandle, Handle, UniqueArena}, + arena::{Arena, Handle, UniqueArena}, proc::Alignment, }; @@ -88,8 +88,6 @@ pub enum Disalignment { #[derive(Clone, Debug, thiserror::Error)] pub enum TypeError { - #[error(transparent)] - BadHandle(#[from] BadHandle), #[error("The {0:?} scalar width {1} is not supported")] InvalidWidth(crate::ScalarKind, crate::Bytes), #[error("The {0:?} scalar width {1} is not supported for an atomic")] @@ -418,7 +416,7 @@ impl super::Validator { let sized_flag = match size { crate::ArraySize::Constant(const_handle) => { - let constant = constants.try_get(const_handle)?; + let constant = &constants[const_handle]; let length_is_positive = match *constant { crate::Constant { specialization: Some(_),