diff --git a/cli/src/main.rs b/cli/src/main.rs index abb00aaa7b..4330d882d0 100644 --- a/cli/src/main.rs +++ b/cli/src/main.rs @@ -231,7 +231,7 @@ fn run() -> Result<(), Box> { } params.keep_coordinate_space = args.keep_coordinate_space; - let module = match Path::new(&input_path) + let (module, input_text) = match Path::new(&input_path) .extension() .ok_or(CliError("Input filename has no extension"))? .to_str() @@ -246,13 +246,13 @@ fn run() -> Result<(), Box> { .map(std::path::PathBuf::from), }; let input = fs::read(input_path)?; - naga::front::spv::parse_u8_slice(&input, &options)? + naga::front::spv::parse_u8_slice(&input, &options).map(|m| (m, None))? } "wgsl" => { let input = fs::read_to_string(input_path)?; let result = naga::front::wgsl::parse_str(&input); match result { - Ok(v) => v, + Ok(v) => (v, Some(input)), Err(ref e) => { e.emit_to_stderr(&input); return Err(CliError("Could not parse WGSL").into()); @@ -263,24 +263,27 @@ fn run() -> Result<(), Box> { let input = fs::read_to_string(input_path)?; let mut parser = naga::front::glsl::Parser::default(); - parser - .parse( - &naga::front::glsl::Options { - stage: match ext { - "vert" => naga::ShaderStage::Vertex, - "frag" => naga::ShaderStage::Fragment, - "comp" => naga::ShaderStage::Compute, - _ => unreachable!(), + ( + parser + .parse( + &naga::front::glsl::Options { + stage: match ext { + "vert" => naga::ShaderStage::Vertex, + "frag" => naga::ShaderStage::Fragment, + "comp" => naga::ShaderStage::Compute, + _ => unreachable!(), + }, + defines: Default::default(), }, - defines: Default::default(), - }, - &input, - ) - .unwrap_or_else(|errors| { - let filename = input_path.file_name().and_then(std::ffi::OsStr::to_str); - emit_glsl_parser_error(errors, filename.unwrap_or("glsl"), &input); - std::process::exit(1); - }) + &input, + ) + .unwrap_or_else(|errors| { + let filename = input_path.file_name().and_then(std::ffi::OsStr::to_str); + emit_glsl_parser_error(errors, filename.unwrap_or("glsl"), &input); + std::process::exit(1); + }), + Some(input), + ) } _ => return Err(CliError("Unknown input file extension").into()), }; @@ -294,6 +297,10 @@ fn run() -> Result<(), Box> { { Ok(info) => Some(info), Err(error) => { + if let Some(input) = input_text { + let filename = input_path.file_name().and_then(std::ffi::OsStr::to_str); + emit_annotated_error(&error, filename.unwrap_or("input"), &input); + } print_err(&error); None } @@ -468,6 +475,7 @@ use codespan_reporting::{ termcolor::{ColorChoice, StandardStream}, }, }; +use naga::WithSpan; pub fn emit_glsl_parser_error(errors: Vec, filename: &str, source: &str) { let files = SimpleFile::new(filename, source); @@ -484,3 +492,20 @@ pub fn emit_glsl_parser_error(errors: Vec, filename: & term::emit(&mut writer.lock(), &config, &files, &diagnostic).expect("cannot write error"); } } + +pub fn emit_annotated_error(ann_err: &WithSpan, filename: &str, source: &str) { + let files = SimpleFile::new(filename, source); + let config = codespan_reporting::term::Config::default(); + let writer = StandardStream::stderr(ColorChoice::Auto); + + let diagnostic = Diagnostic::error().with_labels( + ann_err + .spans() + .map(|(span, desc)| { + Label::primary((), span.to_range().unwrap()).with_message(desc.to_owned()) + }) + .collect(), + ); + + term::emit(&mut writer.lock(), &config, &files, &diagnostic).expect("cannot write error"); +} diff --git a/src/arena.rs b/src/arena.rs index 536a4bd280..4884010d3c 100644 --- a/src/arena.rs +++ b/src/arena.rs @@ -31,28 +31,35 @@ impl Clone for Handle { } } } + impl Copy for Handle {} + impl PartialEq for Handle { fn eq(&self, other: &Self) -> bool { self.index == other.index } } + impl Eq for Handle {} + impl PartialOrd for Handle { fn partial_cmp(&self, other: &Self) -> Option { self.index.partial_cmp(&other.index) } } + impl Ord for Handle { fn cmp(&self, other: &Self) -> Ordering { self.index.cmp(&other.index) } } + impl fmt::Debug for Handle { fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result { write!(formatter, "[{}]", self.index) } } + impl hash::Hash for Handle { fn hash(&self, hasher: &mut H) { self.index.hash(hasher) @@ -117,11 +124,13 @@ impl Clone for Range { } } } + impl fmt::Debug for Range { fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result { write!(formatter, "[{}..{}]", self.inner.start + 1, self.inner.end) } } + impl Iterator for Range { type Item = Handle; fn next(&mut self) -> Option { @@ -159,6 +168,7 @@ impl Default for Arena { Self::new() } } + impl fmt::Debug for Arena { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { f.debug_map().entries(self.iter()).finish() diff --git a/src/block.rs b/src/block.rs index 1a35e43949..6a31301e11 100644 --- a/src/block.rs +++ b/src/block.rs @@ -78,6 +78,14 @@ impl Block { .splice(range.clone(), other.span_info.into_iter()); self.body.splice(range, other.body.into_iter()); } + pub fn span_iter(&self) -> impl Iterator { + #[cfg(feature = "span")] + let span_iter = self.span_info.iter(); + #[cfg(not(feature = "span"))] + let span_iter = std::iter::repeat_with(|| &Span::UNDEFINED); + + self.body.iter().zip(span_iter) + } pub fn span_iter_mut(&mut self) -> impl Iterator)> { #[cfg(feature = "span")] diff --git a/src/lib.rs b/src/lib.rs index 2e9f608314..99b77b267d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -220,7 +220,7 @@ use std::{ hash::BuildHasherDefault, }; -pub use crate::span::Span; +pub use crate::span::{Span, SpanContext, WithSpan}; #[cfg(feature = "deserialize")] use serde::Deserialize; #[cfg(feature = "serialize")] diff --git a/src/span.rs b/src/span.rs index 3eaadd212f..dabcabdb0f 100644 --- a/src/span.rs +++ b/src/span.rs @@ -1,4 +1,5 @@ -use std::ops::Range; +use crate::{Arena, Handle, UniqueArena}; +use std::{error::Error, fmt, ops::Range}; /// A source code span, used for error reporting. #[derive(Clone, Copy, Debug, PartialEq, Default)] @@ -8,6 +9,7 @@ pub struct Span { } impl Span { + pub const UNDEFINED: Self = Self { start: 0, end: 0 }; /// Creates a new `Span` from a range of byte indices /// /// Note: end is exclusive, it doesn't belong to the `Span` @@ -66,3 +68,202 @@ impl From> for Span { } } } + +/// A source code span together with "context", a user-readable description of what part of the error it refers to. +pub type SpanContext = (Span, String); + +/// Wrapper class for [`Error`], augmenting it with a list of [`SpanContext`]s. +#[derive(Debug)] +pub struct WithSpan { + inner: E, + #[cfg(feature = "span")] + spans: Vec, +} + +impl fmt::Display for WithSpan +where + E: fmt::Display, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result { + self.inner.fmt(f) + } +} + +#[cfg(test)] +impl PartialEq for WithSpan +where + E: PartialEq, +{ + fn eq(&self, other: &Self) -> bool { + self.inner.eq(&other.inner) + } +} + +impl Error for WithSpan +where + E: Error, +{ + fn source(&self) -> Option<&(dyn Error + 'static)> { + self.inner.source() + } +} + +impl WithSpan { + /// Create a new [`WithSpan`] from an [`Error`], containing no spans. + pub fn new(inner: E) -> Self { + Self { + inner, + #[cfg(feature = "span")] + spans: Vec::new(), + } + } + + /// Reverse of [`Self::new`], discards span information and returns an inner error. + pub fn into_inner(self) -> E { + self.inner + } + + /// Iterator over stored [`SpanContext`]s. + pub fn spans(&self) -> impl Iterator { + #[cfg(feature = "span")] + return self.spans.iter(); + #[cfg(not(feature = "span"))] + return std::iter::empty(); + } + + /// Add a new span with description. + #[cfg_attr(not(feature = "span"), allow(unused_variables, unused_mut))] + pub fn with_span(mut self, span: Span, description: S) -> Self + where + S: ToString, + { + #[cfg(feature = "span")] + if span.is_defined() { + self.spans.push((span, description.to_string())); + } + self + } + + /// Add a [`SpanContext`]. + pub fn with_context(self, span_context: SpanContext) -> Self { + let (span, description) = span_context; + self.with_span(span, description) + } + + /// Add a [`Handle`] from either [`Arena`] or [`UniqueArena`], borrowing its span information from there + /// and annotating with a type and the handle representation. + pub(crate) fn with_handle>(self, handle: Handle, arena: &A) -> Self { + self.with_context(arena.get_span_context(handle)) + } + + /// Convert inner error using [`From`]. + pub fn into_other(self) -> WithSpan + where + E2: From, + { + WithSpan { + inner: self.inner.into(), + #[cfg(feature = "span")] + spans: self.spans, + } + } + + /// Convert inner error into another type. Joins span information contained in `self` + /// with what is returned from `func`. + pub fn and_then(self, func: F) -> WithSpan + where + F: FnOnce(E) -> WithSpan, + { + #[cfg_attr(not(feature = "span"), allow(unused_mut))] + let mut res = func(self.inner); + #[cfg(feature = "span")] + res.spans.extend(self.spans); + res + } +} + +/// Convenience trait for [`Error`] to be able to apply spans to anything. +pub(crate) trait AddSpan: Sized { + type Output; + /// See [`WithSpan::new`]. + fn with_span(self) -> Self::Output; + /// See [`WithSpan::with_span`]. + fn with_span_static(self, span: Span, description: &'static str) -> Self::Output; + /// See [`WithSpan::with_context`]. + fn with_span_context(self, span_context: SpanContext) -> Self::Output; + /// See [`WithSpan::with_handle`]. + fn with_span_handle>(self, handle: Handle, arena: &A) -> Self::Output; +} + +/// Trait abstracting over getting a span from an [`Arena`] or a [`UniqueArena`]. +pub(crate) trait SpanProvider { + fn get_span(&self, handle: Handle) -> Span; + fn get_span_context(&self, handle: Handle) -> SpanContext { + match self.get_span(handle) { + x if !x.is_defined() => (Default::default(), "".to_string()), + known => ( + known, + format!("{} {:?}", std::any::type_name::(), handle), + ), + } + } +} + +impl SpanProvider for Arena { + fn get_span(&self, handle: Handle) -> Span { + self.get_span(handle) + } +} + +impl SpanProvider for UniqueArena { + fn get_span(&self, handle: Handle) -> Span { + self.get_span(handle) + } +} + +impl AddSpan for E +where + E: Error, +{ + type Output = WithSpan; + fn with_span(self) -> WithSpan { + WithSpan::new(self) + } + + fn with_span_static(self, span: Span, description: &'static str) -> WithSpan { + WithSpan::new(self).with_span(span, description) + } + + fn with_span_context(self, span_context: SpanContext) -> WithSpan { + WithSpan::new(self).with_context(span_context) + } + + fn with_span_handle>( + self, + handle: Handle, + arena: &A, + ) -> WithSpan { + WithSpan::new(self).with_handle(handle, arena) + } +} + +/// Convenience trait for [`Result`], adding a [`MapErrWithSpan::map_err_inner`] +/// mapping to [`WithSpan::and_then`]. +pub trait MapErrWithSpan: Sized { + type Output: Sized; + fn map_err_inner(self, func: F) -> Self::Output + where + F: FnOnce(E) -> WithSpan, + E2: From; +} + +impl MapErrWithSpan for Result> { + type Output = Result>; + fn map_err_inner(self, func: F) -> Result> + where + F: FnOnce(E) -> WithSpan, + E2: From, + { + self.map_err(|e| e.and_then(func).into_other::()) + } +} diff --git a/src/valid/analyzer.rs b/src/valid/analyzer.rs index d6e2452ad7..051bf4ca86 100644 --- a/src/valid/analyzer.rs +++ b/src/valid/analyzer.rs @@ -7,6 +7,7 @@ Figures out the following properties: !*/ use super::{CallError, ExpressionError, FunctionError, ModuleInfo, ShaderStages, ValidationFlags}; +use crate::span::{AddSpan as _, WithSpan}; use crate::{ arena::{Arena, Handle}, proc::{ResolveContext, TypeResolution}, @@ -307,7 +308,7 @@ impl FunctionInfo { info: &Self, arguments: &[Handle], expression_arena: &Arena, - ) -> Result { + ) -> Result> { for key in info.sampling_set.iter() { self.sampling_set.insert(key.clone()); } @@ -318,7 +319,10 @@ impl FunctionInfo { let handle = arguments[i as usize]; expression_arena[handle] .to_global_or_argument() - .map_err(|error| FunctionError::Expression { handle, error })? + .map_err(|error| { + FunctionError::Expression { handle, error } + .with_span_handle(handle, expression_arena) + })? } }; @@ -328,7 +332,10 @@ impl FunctionInfo { let handle = arguments[i as usize]; expression_arena[handle] .to_global_or_argument() - .map_err(|error| FunctionError::Expression { handle, error })? + .map_err(|error| { + FunctionError::Expression { handle, error } + .with_span_handle(handle, expression_arena) + })? } }; @@ -608,15 +615,15 @@ impl FunctionInfo { #[allow(clippy::or_fun_call)] fn process_block( &mut self, - statements: &[crate::Statement], + statements: &crate::Block, other_functions: &[FunctionInfo], mut disruptor: Option, expression_arena: &Arena, - ) -> Result { + ) -> Result> { use crate::Statement as S; let mut combined_uniformity = FunctionUniformity::new(); - for statement in statements { + for (statement, &span) in statements.span_iter() { let uniformity = match *statement { S::Emit(ref range) => { let mut requirements = UniformityRequirements::empty(); @@ -629,7 +636,8 @@ impl FunctionInfo { && !req.is_empty() { if let Some(cause) = disruptor { - return Err(FunctionError::NonUniformControlFlow(req, expr, cause)); + return Err(FunctionError::NonUniformControlFlow(req, expr, cause) + .with_span_handle(expr, expression_arena)); } } requirements |= req; @@ -776,7 +784,8 @@ impl FunctionInfo { FunctionError::InvalidCall { function, error: CallError::ForwardDeclaredFunction, - }, + } + .with_span_static(span, "forward call"), )?; //Note: the result is validated by the Validator, not here self.process_call(info, arguments, expression_arena)? @@ -811,7 +820,7 @@ impl ModuleInfo { fun: &crate::Function, module: &crate::Module, flags: ValidationFlags, - ) -> Result { + ) -> Result> { let mut info = FunctionInfo { flags, available_stages: ShaderStages::all(), @@ -839,7 +848,8 @@ impl ModuleInfo { &self.functions, &resolve_context, ) { - return Err(FunctionError::Expression { handle, error }); + return Err(FunctionError::Expression { handle, error } + .with_span_handle(handle, &fun.expressions)); } } @@ -979,7 +989,12 @@ fn uniform_control_flow() { .into(), }; assert_eq!( - info.process_block(&[stmt_emit1, stmt_if_uniform], &[], None, &expressions), + info.process_block( + &vec![stmt_emit1, stmt_if_uniform].into(), + &[], + None, + &expressions + ), Ok(FunctionUniformity { result: Uniformity { non_uniform_result: None, @@ -1005,12 +1020,18 @@ fn uniform_control_flow() { reject: crate::Block::new(), }; assert_eq!( - info.process_block(&[stmt_emit2, stmt_if_non_uniform], &[], None, &expressions), + info.process_block( + &vec![stmt_emit2, stmt_if_non_uniform].into(), + &[], + None, + &expressions + ), Err(FunctionError::NonUniformControlFlow( UniformityRequirements::DERIVATIVE, derivative_expr, UniformityDisruptor::Expression(non_uniform_global_expr) - )), + ) + .with_span()), ); assert_eq!(info[derivative_expr].ref_count, 1); assert_eq!(info[non_uniform_global], GlobalUse::READ); @@ -1021,7 +1042,7 @@ fn uniform_control_flow() { }; assert_eq!( info.process_block( - &[stmt_emit3, stmt_return_non_uniform], + &vec![stmt_emit3, stmt_return_non_uniform].into(), &[], Some(UniformityDisruptor::Return), &expressions @@ -1048,7 +1069,7 @@ fn uniform_control_flow() { let stmt_kill = S::Kill; assert_eq!( info.process_block( - &[stmt_emit4, stmt_assign, stmt_kill, stmt_return_pointer], + &vec![stmt_emit4, stmt_assign, stmt_kill, stmt_return_pointer].into(), &[], Some(UniformityDisruptor::Discard), &expressions diff --git a/src/valid/function.rs b/src/valid/function.rs index 4ee4a08bc9..0df8e2f561 100644 --- a/src/valid/function.rs +++ b/src/valid/function.rs @@ -6,6 +6,7 @@ use super::{ analyzer::{UniformityDisruptor, UniformityRequirements}, ExpressionError, FunctionInfo, ModuleInfo, }; +use crate::span::{AddSpan as _, MapErrWithSpan as _, WithSpan}; #[cfg(feature = "validate")] use bit_set::BitSet; @@ -201,11 +202,11 @@ impl<'a> BlockContext<'a> { &self, handle: Handle, valid_expressions: &BitSet, - ) -> Result<&crate::TypeInner, ExpressionError> { + ) -> Result<&crate::TypeInner, WithSpan> { if handle.index() >= self.expressions.len() { - Err(ExpressionError::DoesntExist) + Err(ExpressionError::DoesntExist.with_span()) } else if !valid_expressions.contains(handle.index()) { - Err(ExpressionError::NotInScope) + Err(ExpressionError::NotInScope.with_span_handle(handle, self.expressions)) } else { Ok(self.info[handle].ty.inner_with(self.types)) } @@ -215,9 +216,9 @@ impl<'a> BlockContext<'a> { &self, handle: Handle, valid_expressions: &BitSet, - ) -> Result<&crate::TypeInner, FunctionError> { + ) -> Result<&crate::TypeInner, WithSpan> { self.resolve_type_impl(handle, valid_expressions) - .map_err(|error| FunctionError::Expression { handle, error }) + .map_err_inner(|error| FunctionError::Expression { handle, error }.with_span()) } fn resolve_pointer_type( @@ -243,28 +244,33 @@ impl super::Validator { arguments: &[Handle], result: Option>, context: &BlockContext, - ) -> Result { + ) -> Result> { let fun = context .functions .try_get(function) - .ok_or(CallError::InvalidFunction)?; + .ok_or(CallError::InvalidFunction) + .map_err(WithSpan::new)?; if fun.arguments.len() != arguments.len() { return Err(CallError::ArgumentCount { required: fun.arguments.len(), seen: arguments.len(), - }); + } + .with_span()); } for (index, (arg, &expr)) in fun.arguments.iter().zip(arguments).enumerate() { let ty = context .resolve_type_impl(expr, &self.valid_expression_set) - .map_err(|error| CallError::Argument { index, error })?; + .map_err_inner(|error| { + CallError::Argument { index, error }.with_span_handle(expr, context.expressions) + })?; let arg_inner = &context.types[arg.ty].inner; if !ty.equivalent(arg_inner, context.types) { return Err(CallError::ArgumentType { index, required: arg.ty, seen_expression: expr, - }); + } + .with_span_handle(expr, context.expressions)); } } @@ -272,15 +278,19 @@ impl super::Validator { if self.valid_expression_set.insert(expr.index()) { self.valid_expression_list.push(expr); } else { - return Err(CallError::ResultAlreadyInScope(expr)); + return Err(CallError::ResultAlreadyInScope(expr) + .with_span_handle(expr, context.expressions)); } match context.expressions[expr] { crate::Expression::CallResult(callee) if fun.result.is_some() && callee == function => {} - _ => return Err(CallError::ExpressionMismatch(result)), + _ => { + return Err(CallError::ExpressionMismatch(result) + .with_span_handle(expr, context.expressions)) + } } } else if fun.result.is_some() { - return Err(CallError::ExpressionMismatch(result)); + return Err(CallError::ExpressionMismatch(result).with_span()); } let callee_info = &context.prev_infos[function.index()]; @@ -295,19 +305,23 @@ impl super::Validator { value: Handle, result: Handle, context: &BlockContext, - ) -> Result<(), FunctionError> { + ) -> Result<(), WithSpan> { let pointer_inner = context.resolve_type(pointer, &self.valid_expression_set)?; let (ptr_kind, ptr_width) = match *pointer_inner { crate::TypeInner::Pointer { base, .. } => match context.types[base].inner { crate::TypeInner::Atomic { kind, width } => (kind, width), ref other => { log::error!("Atomic pointer to type {:?}", other); - return Err(AtomicError::InvalidPointer(pointer).into()); + return Err(AtomicError::InvalidPointer(pointer) + .with_span_handle(pointer, context.expressions) + .into_other()); } }, ref other => { log::error!("Atomic on type {:?}", other); - return Err(AtomicError::InvalidPointer(pointer).into()); + return Err(AtomicError::InvalidPointer(pointer) + .with_span_handle(pointer, context.expressions) + .into_other()); } }; @@ -316,21 +330,27 @@ impl super::Validator { crate::TypeInner::Scalar { width, kind } if kind == ptr_kind && width == ptr_width => {} ref other => { log::error!("Atomic operand type {:?}", other); - return Err(AtomicError::InvalidOperand(value).into()); + return Err(AtomicError::InvalidOperand(value) + .with_span_handle(value, context.expressions) + .into_other()); } } if let crate::AtomicFunction::Exchange { compare: Some(cmp) } = *fun { if context.resolve_type(cmp, &self.valid_expression_set)? != value_inner { log::error!("Atomic exchange comparison has a different type from the value"); - return Err(AtomicError::InvalidOperand(cmp).into()); + return Err(AtomicError::InvalidOperand(cmp) + .with_span_handle(cmp, context.expressions) + .into_other()); } } if self.valid_expression_set.insert(result.index()) { self.valid_expression_list.push(result); } else { - return Err(AtomicError::ResultAlreadyInScope(result).into()); + return Err(AtomicError::ResultAlreadyInScope(result) + .with_span_handle(result, context.expressions) + .into_other()); } match context.expressions[result] { //TODO: support atomic result with comparison @@ -339,7 +359,11 @@ impl super::Validator { width, comparison: false, } if kind == ptr_kind && width == ptr_width => {} - _ => return Err(AtomicError::ResultTypeMismatch(result).into()), + _ => { + return Err(AtomicError::ResultTypeMismatch(result) + .with_span_handle(result, context.expressions) + .into_other()) + } } Ok(()) } @@ -347,15 +371,16 @@ impl super::Validator { #[cfg(feature = "validate")] fn validate_block_impl( &mut self, - statements: &[crate::Statement], + statements: &crate::Block, context: &BlockContext, - ) -> Result { + ) -> Result> { use crate::{Statement as S, TypeInner as Ti}; let mut finished = false; let mut stages = super::ShaderStages::all(); - for statement in statements { + for (statement, &span) in statements.span_iter() { if finished { - return Err(FunctionError::InstructionsAfterReturn); + return Err(FunctionError::InstructionsAfterReturn + .with_span_static(span, "instructions after return")); } match *statement { S::Emit(ref range) => { @@ -363,7 +388,8 @@ impl super::Validator { if self.valid_expression_set.insert(handle.index()) { self.valid_expression_list.push(handle); } else { - return Err(FunctionError::ExpressionAlreadyInScope(handle)); + return Err(FunctionError::ExpressionAlreadyInScope(handle) + .with_span_handle(handle, context.expressions)); } } } @@ -382,7 +408,10 @@ impl super::Validator { kind: crate::ScalarKind::Bool, width: _, } => {} - _ => return Err(FunctionError::InvalidIfType(condition)), + _ => { + return Err(FunctionError::InvalidIfType(condition) + .with_span_handle(condition, context.expressions)) + } } stages &= self.validate_block(accept, context)?.stages; stages &= self.validate_block(reject, context)?.stages; @@ -401,12 +430,22 @@ impl super::Validator { kind: crate::ScalarKind::Sint, width: _, } => {} - _ => return Err(FunctionError::InvalidSwitchType(selector)), + _ => { + return Err(FunctionError::InvalidSwitchType(selector) + .with_span_handle(selector, context.expressions)) + } } self.select_cases.clear(); for case in cases { if !self.select_cases.insert(case.value) { - return Err(FunctionError::ConflictingSwitchCase(case.value)); + return Err(FunctionError::ConflictingSwitchCase(case.value) + .with_span_static( + case.body + .span_iter() + .next() + .map_or(Default::default(), |(_, s)| *s), + "conflicting switch arm here", + )); } } let pass_through_abilities = context.abilities @@ -448,19 +487,22 @@ impl super::Validator { } S::Break => { if !context.abilities.contains(ControlFlowAbility::BREAK) { - return Err(FunctionError::BreakOutsideOfLoopOrSwitch); + return Err(FunctionError::BreakOutsideOfLoopOrSwitch + .with_span_static(span, "invalid break")); } finished = true; } S::Continue => { if !context.abilities.contains(ControlFlowAbility::CONTINUE) { - return Err(FunctionError::ContinueOutsideOfLoop); + return Err(FunctionError::ContinueOutsideOfLoop + .with_span_static(span, "invalid continue")); } finished = true; } S::Return { value } => { if !context.abilities.contains(ControlFlowAbility::RETURN) { - return Err(FunctionError::InvalidReturnSpot); + return Err(FunctionError::InvalidReturnSpot + .with_span_static(span, "invalid return")); } let value_ty = value .map(|expr| context.resolve_type(expr, &self.valid_expression_set)) @@ -482,7 +524,13 @@ impl super::Validator { value_ty, expected_ty ); - return Err(FunctionError::InvalidReturnType(value)); + if let Some(handle) = value { + return Err(FunctionError::InvalidReturnType(value) + .with_span_handle(handle, context.expressions)); + } else { + return Err(FunctionError::InvalidReturnType(value) + .with_span_static(span, "invalid return")); + } } finished = true; } @@ -495,25 +543,34 @@ impl super::Validator { S::Store { pointer, value } => { let mut current = pointer; loop { - let _ = context.resolve_pointer_type(current)?; - match *context.get_expression(current)? { + let _ = context + .resolve_pointer_type(current) + .map_err(|e| e.with_span())?; + match context.expressions[current] { crate::Expression::Access { base, .. } | crate::Expression::AccessIndex { base, .. } => current = base, crate::Expression::LocalVariable(_) | crate::Expression::GlobalVariable(_) | crate::Expression::FunctionArgument(_) => break, - _ => return Err(FunctionError::InvalidStorePointer(current)), + _ => { + return Err(FunctionError::InvalidStorePointer(current) + .with_span_handle(pointer, context.expressions)) + } } } let value_ty = context.resolve_type(value, &self.valid_expression_set)?; match *value_ty { Ti::Image { .. } | Ti::Sampler { .. } => { - return Err(FunctionError::InvalidStoreValue(value)); + return Err(FunctionError::InvalidStoreValue(value) + .with_span_handle(value, context.expressions)); } _ => {} } - let good = match *context.resolve_pointer_type(pointer)? { + let good = match *context + .resolve_pointer_type(pointer) + .map_err(|e| e.with_span())? + { Ti::Pointer { base, class: _ } => match context.types[base].inner { Ti::Atomic { kind, width } => *value_ty == Ti::Scalar { kind, width }, ref other => value_ty == other, @@ -533,7 +590,10 @@ impl super::Validator { _ => false, }; if !good { - return Err(FunctionError::InvalidStoreTypes { pointer, value }); + return Err(FunctionError::InvalidStoreTypes { pointer, value } + .with_span() + .with_handle(pointer, context.expressions) + .with_handle(value, context.expressions)); } } S::ImageStore { @@ -544,14 +604,15 @@ impl super::Validator { } => { //Note: this code uses a lot of `FunctionError::InvalidImageStore`, // and could probably be refactored. - let var = match *context.get_expression(image)? { + let var = match *context.get_expression(image).map_err(|e| e.with_span())? { crate::Expression::GlobalVariable(var_handle) => { &context.global_vars[var_handle] } _ => { return Err(FunctionError::InvalidImageStore( ExpressionError::ExpectedGlobalVariable, - )) + ) + .with_span_handle(image, context.expressions)) } }; @@ -571,13 +632,15 @@ impl super::Validator { ExpressionError::InvalidImageCoordinateType( dim, coordinate, ), - )) + ) + .with_span_handle(coordinate, context.expressions)); } }; if arrayed != array_index.is_some() { return Err(FunctionError::InvalidImageStore( ExpressionError::InvalidImageArrayIndex, - )); + ) + .with_span_handle(coordinate, context.expressions)); } if let Some(expr) = array_index { match *context.resolve_type(expr, &self.valid_expression_set)? { @@ -588,7 +651,8 @@ impl super::Validator { _ => { return Err(FunctionError::InvalidImageStore( ExpressionError::InvalidImageArrayIndexType(expr), - )) + ) + .with_span_handle(expr, context.expressions)); } } } @@ -603,19 +667,24 @@ impl super::Validator { _ => { return Err(FunctionError::InvalidImageStore( ExpressionError::InvalidImageClass(class), - )) + ) + .with_span_handle(image, context.expressions)); } } } _ => { return Err(FunctionError::InvalidImageStore( ExpressionError::ExpectedImageType(var.ty), - )) + ) + .with_span() + .with_handle(var.ty, context.types) + .with_handle(image, context.expressions)) } }; if *context.resolve_type(value, &self.valid_expression_set)? != value_ty { - return Err(FunctionError::InvalidStoreValue(value)); + return Err(FunctionError::InvalidStoreValue(value) + .with_span_handle(value, context.expressions)); } } S::Call { @@ -624,7 +693,12 @@ impl super::Validator { result, } => match self.validate_call(function, arguments, result, context) { Ok(callee_stages) => stages &= callee_stages, - Err(error) => return Err(FunctionError::InvalidCall { function, error }), + Err(error) => { + return Err(error.and_then(|error| { + FunctionError::InvalidCall { function, error } + .with_span_static(span, "invalid function call") + })) + } }, S::Atomic { pointer, @@ -642,9 +716,9 @@ impl super::Validator { #[cfg(feature = "validate")] fn validate_block( &mut self, - statements: &[crate::Statement], + statements: &crate::Block, context: &BlockContext, - ) -> Result { + ) -> Result> { let base_expression_count = self.valid_expression_list.len(); let info = self.validate_block_impl(statements, context)?; for handle in self.valid_expression_list.drain(base_expression_count..) { @@ -693,9 +767,11 @@ impl super::Validator { fun: &crate::Function, module: &crate::Module, mod_info: &ModuleInfo, - ) -> Result { + ) -> Result> { #[cfg(feature = "validate")] - let mut info = mod_info.process_function(fun, module, self.flags)?; + let mut info = mod_info + .process_function(fun, module, self.flags) + .map_err(WithSpan::into_other)?; #[cfg(not(feature = "validate"))] let info = mod_info.process_function(fun, module, self.flags)?; @@ -703,10 +779,14 @@ impl super::Validator { #[cfg(feature = "validate")] for (var_handle, var) in fun.local_variables.iter() { self.validate_local_var(var, &module.types, &module.constants) - .map_err(|error| FunctionError::LocalVariable { - handle: var_handle, - name: var.name.clone().unwrap_or_default(), - error, + .map_err(|error| { + FunctionError::LocalVariable { + handle: var_handle, + name: var.name.clone().unwrap_or_default(), + error, + } + .with_span_handle(var.ty, &module.types) + .with_handle(var_handle, &fun.local_variables) })?; } @@ -719,7 +799,8 @@ impl super::Validator { return Err(FunctionError::InvalidArgumentType { index, name: argument.name.clone().unwrap_or_default(), - }); + } + .with_span_handle(argument.ty, &module.types)); } match module.types[argument.ty].inner.pointer_class() { Some(crate::StorageClass::Private) @@ -731,7 +812,8 @@ impl super::Validator { index, name: argument.name.clone().unwrap_or_default(), class: other, - }) + } + .with_span_handle(argument.ty, &module.types)) } } } @@ -753,7 +835,10 @@ impl super::Validator { &mod_info.functions, ) { Ok(stages) => info.available_stages &= stages, - Err(error) => return Err(FunctionError::Expression { handle, error }), + Err(error) => { + return Err(FunctionError::Expression { handle, error } + .with_span_handle(handle, &fun.expressions)) + } } } } diff --git a/src/valid/interface.rs b/src/valid/interface.rs index 7ab34d703f..68326b8a8d 100644 --- a/src/valid/interface.rs +++ b/src/valid/interface.rs @@ -4,6 +4,7 @@ use super::{ }; use crate::arena::{Handle, UniqueArena}; +use crate::span::{AddSpan as _, MapErrWithSpan as _, SpanProvider as _, WithSpan}; use bit_set::BitSet; #[cfg(feature = "validate")] @@ -70,8 +71,8 @@ pub enum EntryPointError { BindingCollision(Handle), #[error("Argument {0} varying error")] Argument(u32, #[source] VaryingError), - #[error("Result varying error")] - Result(#[source] VaryingError), + #[error(transparent)] + Result(#[from] VaryingError), #[error("Location {location} onterpolation of an integer has to be flat")] InvalidIntegerInterpolation { location: u32 }, #[error(transparent)] @@ -288,9 +289,12 @@ impl VaryingContext<'_> { Ok(()) } - fn validate(&mut self, binding: Option<&crate::Binding>) -> Result<(), VaryingError> { + fn validate(&mut self, binding: Option<&crate::Binding>) -> Result<(), WithSpan> { + let span_context = self.types.get_span_context(self.ty); match binding { - Some(binding) => self.validate_impl(binding), + Some(binding) => self + .validate_impl(binding) + .map_err(|e| e.with_span_context(span_context)), None => { match self.types[self.ty].inner { //TODO: check the member types @@ -301,15 +305,20 @@ impl VaryingContext<'_> { } => { for (index, member) in members.iter().enumerate() { self.ty = member.ty; + let span_context = self.types.get_span_context(self.ty); match member.binding { None => { - return Err(VaryingError::MemberMissingBinding(index as u32)) + return Err(VaryingError::MemberMissingBinding(index as u32) + .with_span_context(span_context)) } - Some(ref binding) => self.validate_impl(binding)?, + // TODO: shouldn't this be validate? + Some(ref binding) => self + .validate_impl(binding) + .map_err(|e| e.with_span_context(span_context))?, } } } - _ => return Err(VaryingError::MissingBinding), + _ => return Err(VaryingError::MissingBinding.with_span()), } Ok(()) } @@ -399,10 +408,10 @@ impl super::Validator { ep: &crate::EntryPoint, module: &crate::Module, mod_info: &ModuleInfo, - ) -> Result { + ) -> Result> { #[cfg(feature = "validate")] if ep.early_depth_test.is_some() && ep.stage != crate::ShaderStage::Fragment { - return Err(EntryPointError::UnexpectedEarlyDepthTest); + return Err(EntryPointError::UnexpectedEarlyDepthTest.with_span()); } #[cfg(feature = "validate")] @@ -412,13 +421,15 @@ impl super::Validator { .iter() .any(|&s| s == 0 || s > MAX_WORKGROUP_SIZE) { - return Err(EntryPointError::OutOfRangeWorkgroupSize); + return Err(EntryPointError::OutOfRangeWorkgroupSize.with_span()); } } else if ep.workgroup_size != [0; 3] { - return Err(EntryPointError::UnexpectedWorkgroupSize); + return Err(EntryPointError::UnexpectedWorkgroupSize.with_span()); } - let info = self.validate_function(&ep.function, module, mod_info)?; + let info = self + .validate_function(&ep.function, module, mod_info) + .map_err(WithSpan::into_other)?; #[cfg(feature = "validate")] { @@ -431,12 +442,13 @@ impl super::Validator { }; if !info.available_stages.contains(stage_bit) { - return Err(EntryPointError::ForbiddenStageOperations); + return Err(EntryPointError::ForbiddenStageOperations.with_span()); } } self.location_mask.clear(); let mut argument_built_ins = 0; + // TODO: add span info to function arguments for (index, fa) in ep.function.arguments.iter().enumerate() { let mut ctx = VaryingContext { ty: fa.ty, @@ -448,7 +460,7 @@ impl super::Validator { capabilities: self.capabilities, }; ctx.validate(fa.binding.as_ref()) - .map_err(|e| EntryPointError::Argument(index as u32, e))?; + .map_err_inner(|e| EntryPointError::Argument(index as u32, e).with_span())?; argument_built_ins = ctx.built_in_mask; } @@ -464,7 +476,7 @@ impl super::Validator { capabilities: self.capabilities, }; ctx.validate(fr.binding.as_ref()) - .map_err(EntryPointError::Result)?; + .map_err_inner(|e| EntryPointError::Result(e).with_span())?; } for bg in self.bind_group_masks.iter_mut() { @@ -499,7 +511,8 @@ impl super::Validator { allowed_usage, usage ); - return Err(EntryPointError::InvalidGlobalUsage(var_handle, usage)); + return Err(EntryPointError::InvalidGlobalUsage(var_handle, usage) + .with_span_handle(var_handle, &module.global_variables)); } if let Some(ref bind) = var.binding { @@ -507,7 +520,8 @@ impl super::Validator { self.bind_group_masks.push(BitSet::new()); } if !self.bind_group_masks[bind.group as usize].insert(bind.binding as usize) { - return Err(EntryPointError::BindingCollision(var_handle)); + return Err(EntryPointError::BindingCollision(var_handle) + .with_span_handle(var_handle, &module.global_variables)); } } } diff --git a/src/valid/mod.rs b/src/valid/mod.rs index fb72cdafcd..6b2c804757 100644 --- a/src/valid/mod.rs +++ b/src/valid/mod.rs @@ -19,6 +19,7 @@ use std::ops; //TODO: analyze the model at the same time as we validate it, // merge the corresponding matches over expressions and statements. +use crate::span::{AddSpan as _, WithSpan}; pub use analyzer::{ExpressionInfo, FunctionInfo, GlobalUse, Uniformity, UniformityRequirements}; pub use compose::ComposeError; pub use expression::ExpressionError; @@ -265,29 +266,43 @@ impl Validator { } /// Check the given module to be valid. - pub fn validate(&mut self, module: &crate::Module) -> Result { + pub fn validate( + &mut self, + module: &crate::Module, + ) -> Result> { self.reset_types(module.types.len()); - self.layouter.update(&module.types, &module.constants)?; + self.layouter + .update(&module.types, &module.constants) + .map_err(|e| { + let InvalidBaseType(handle) = e; + ValidationError::from(e).with_span_handle(handle, &module.types) + })?; #[cfg(feature = "validate")] if self.flags.contains(ValidationFlags::CONSTANTS) { for (handle, constant) in module.constants.iter() { self.validate_constant(handle, &module.constants, &module.types) - .map_err(|error| ValidationError::Constant { - handle, - name: constant.name.clone().unwrap_or_default(), - error, - })?; + .map_err(|error| { + ValidationError::Constant { + handle, + name: constant.name.clone().unwrap_or_default(), + error, + } + .with_span_handle(handle, &module.constants) + })? } } for (handle, ty) in module.types.iter() { let ty_info = self .validate_type(handle, &module.types, &module.constants) - .map_err(|error| ValidationError::Type { - handle, - name: ty.name.clone().unwrap_or_default(), - error, + .map_err(|error| { + ValidationError::Type { + handle, + name: ty.name.clone().unwrap_or_default(), + error, + } + .with_span_handle(handle, &module.types) })?; self.types[handle.index()] = ty_info; } @@ -295,10 +310,13 @@ impl Validator { #[cfg(feature = "validate")] for (var_handle, var) in module.global_variables.iter() { self.validate_global_var(var, &module.types) - .map_err(|error| ValidationError::GlobalVariable { - handle: var_handle, - name: var.name.clone().unwrap_or_default(), - error, + .map_err(|error| { + ValidationError::GlobalVariable { + handle: var_handle, + name: var.name.clone().unwrap_or_default(), + error, + } + .with_span_handle(var_handle, &module.global_variables) })?; } @@ -311,11 +329,14 @@ impl Validator { match self.validate_function(fun, module, &mod_info) { Ok(info) => mod_info.functions.push(info), Err(error) => { - return Err(ValidationError::Function { - handle, - name: fun.name.clone().unwrap_or_default(), - error, - }) + return Err(error.and_then(|error| { + ValidationError::Function { + handle, + name: fun.name.clone().unwrap_or_default(), + error, + } + .with_span_handle(handle, &module.functions) + })) } } } @@ -327,17 +348,21 @@ impl Validator { stage: ep.stage, name: ep.name.clone(), error: EntryPointError::Conflict, - }); + } + .with_span()); // TODO: keep some EP span information? } match self.validate_entry_point(ep, module, &mod_info) { Ok(info) => mod_info.entry_points.push(info), Err(error) => { - return Err(ValidationError::EntryPoint { - stage: ep.stage, - name: ep.name.clone(), - error, - }) + return Err(error.and_then(|inner| { + ValidationError::EntryPoint { + stage: ep.stage, + name: ep.name.clone(), + error: inner, + } + .with_span() + })) } } } diff --git a/tests/wgsl-errors.rs b/tests/wgsl-errors.rs index dedd2042d6..ab9232bbfd 100644 --- a/tests/wgsl-errors.rs +++ b/tests/wgsl-errors.rs @@ -580,6 +580,7 @@ fn validation_error(source: &str) -> Result