From d6ebd88f422b97fc20e4608bd98f94a3bd40055d Mon Sep 17 00:00:00 2001 From: teoxoy <28601907+teoxoy@users.noreply.github.com> Date: Wed, 14 Feb 2024 15:17:07 +0100 Subject: [PATCH] implement override-expression evaluation for initializers of override declarations --- naga/src/arena.rs | 2 + naga/src/back/hlsl/mod.rs | 2 +- naga/src/back/hlsl/writer.rs | 9 +- naga/src/back/msl/mod.rs | 2 +- naga/src/back/msl/writer.rs | 6 +- naga/src/back/pipeline_constants.rs | 335 +++++++++++++++-- naga/src/back/spv/mod.rs | 40 +- naga/src/back/spv/writer.rs | 14 +- naga/src/front/glsl/context.rs | 23 +- naga/src/front/glsl/functions.rs | 10 + naga/src/front/glsl/parser.rs | 17 +- naga/src/front/glsl/parser/declarations.rs | 9 +- naga/src/front/glsl/parser/functions.rs | 7 +- naga/src/front/glsl/types.rs | 7 +- naga/src/front/wgsl/lower/mod.rs | 92 ++++- naga/src/proc/constant_evaluator.rs | 405 ++++++++++++++------- naga/src/proc/mod.rs | 2 +- naga/src/valid/analyzer.rs | 2 +- naga/src/valid/expression.rs | 21 +- naga/src/valid/function.rs | 20 +- naga/src/valid/handles.rs | 2 + naga/src/valid/interface.rs | 13 +- naga/src/valid/mod.rs | 41 ++- naga/src/valid/type.rs | 2 + naga/tests/in/overrides.wgsl | 2 +- naga/tests/out/analysis/overrides.info.ron | 6 + naga/tests/out/hlsl/overrides.hlsl | 1 + naga/tests/out/ir/overrides.compact.ron | 15 +- naga/tests/out/ir/overrides.ron | 15 +- naga/tests/out/msl/overrides.msl | 1 + naga/tests/out/spv/overrides.main.spvasm | 24 +- naga/tests/out/wgsl/quad_glsl.vert.wgsl | 4 +- 32 files changed, 909 insertions(+), 242 deletions(-) diff --git a/naga/src/arena.rs b/naga/src/arena.rs index 184102757e..740df85b86 100644 --- a/naga/src/arena.rs +++ b/naga/src/arena.rs @@ -122,6 +122,7 @@ impl Handle { serde(transparent) )] #[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))] +#[cfg_attr(test, derive(PartialEq))] pub struct Range { inner: ops::Range, #[cfg_attr(any(feature = "serialize", feature = "deserialize"), serde(skip))] @@ -140,6 +141,7 @@ impl Range { // NOTE: Keep this diagnostic in sync with that of [`BadHandle`]. #[derive(Clone, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] #[error("Handle range {range:?} of {kind} is either not present, or inaccessible yet")] pub struct BadRangeError { // This error is used for many `Handle` types, but there's no point in making this generic, so diff --git a/naga/src/back/hlsl/mod.rs b/naga/src/back/hlsl/mod.rs index 588c91d69d..d423b003ff 100644 --- a/naga/src/back/hlsl/mod.rs +++ b/naga/src/back/hlsl/mod.rs @@ -256,7 +256,7 @@ pub enum Error { #[error("{0}")] Custom(String), #[error(transparent)] - PipelineConstant(#[from] back::pipeline_constants::PipelineConstantError), + PipelineConstant(#[from] Box), } #[derive(Default)] diff --git a/naga/src/back/hlsl/writer.rs b/naga/src/back/hlsl/writer.rs index 0db6489840..1abc6ceca0 100644 --- a/naga/src/back/hlsl/writer.rs +++ b/naga/src/back/hlsl/writer.rs @@ -169,9 +169,14 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { module_info: &valid::ModuleInfo, pipeline_options: &PipelineOptions, ) -> Result { - let module = - back::pipeline_constants::process_overrides(module, &pipeline_options.constants)?; + let (module, module_info) = back::pipeline_constants::process_overrides( + module, + module_info, + &pipeline_options.constants, + ) + .map_err(Box::new)?; let module = module.as_ref(); + let module_info = module_info.as_ref(); self.reset(module); diff --git a/naga/src/back/msl/mod.rs b/naga/src/back/msl/mod.rs index 702b373cfc..6ba8227a20 100644 --- a/naga/src/back/msl/mod.rs +++ b/naga/src/back/msl/mod.rs @@ -144,7 +144,7 @@ pub enum Error { #[error("ray tracing is not supported prior to MSL 2.3")] UnsupportedRayTracing, #[error(transparent)] - PipelineConstant(#[from] crate::back::pipeline_constants::PipelineConstantError), + PipelineConstant(#[from] Box), } #[derive(Clone, Debug, PartialEq, thiserror::Error)] diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs index 36d8bc820b..3c2a741cd4 100644 --- a/naga/src/back/msl/writer.rs +++ b/naga/src/back/msl/writer.rs @@ -3223,9 +3223,11 @@ impl Writer { options: &Options, pipeline_options: &PipelineOptions, ) -> Result { - let module = - back::pipeline_constants::process_overrides(module, &pipeline_options.constants)?; + let (module, info) = + back::pipeline_constants::process_overrides(module, info, &pipeline_options.constants) + .map_err(Box::new)?; let module = module.as_ref(); + let info = info.as_ref(); self.names.clear(); self.namer.reset( diff --git a/naga/src/back/pipeline_constants.rs b/naga/src/back/pipeline_constants.rs index 5a3cad2a6d..6b2792dd28 100644 --- a/naga/src/back/pipeline_constants.rs +++ b/naga/src/back/pipeline_constants.rs @@ -1,6 +1,10 @@ use super::PipelineConstants; -use crate::{Constant, Expression, Literal, Module, Scalar, Span, TypeInner}; -use std::borrow::Cow; +use crate::{ + proc::{ConstantEvaluator, ConstantEvaluatorError}, + valid::{Capabilities, ModuleInfo, ValidationError, ValidationFlags, Validator}, + Constant, Expression, Handle, Literal, Module, Override, Scalar, Span, TypeInner, WithSpan, +}; +use std::{borrow::Cow, collections::HashSet}; use thiserror::Error; #[derive(Error, Debug, Clone)] @@ -12,48 +16,317 @@ pub enum PipelineConstantError { SrcNeedsToBeFinite, #[error("Source f64 value doesn't fit in destination")] DstRangeTooSmall, + #[error(transparent)] + ConstantEvaluatorError(#[from] ConstantEvaluatorError), + #[error(transparent)] + ValidationError(#[from] WithSpan), } pub(super) fn process_overrides<'a>( module: &'a Module, + module_info: &'a ModuleInfo, pipeline_constants: &PipelineConstants, -) -> Result, PipelineConstantError> { +) -> Result<(Cow<'a, Module>, Cow<'a, ModuleInfo>), PipelineConstantError> { if module.overrides.is_empty() { - return Ok(Cow::Borrowed(module)); + return Ok((Cow::Borrowed(module), Cow::Borrowed(module_info))); } let mut module = module.clone(); + let mut override_map = Vec::with_capacity(module.overrides.len()); + let mut adjusted_const_expressions = Vec::with_capacity(module.const_expressions.len()); + let mut adjusted_constant_initializers = HashSet::with_capacity(module.constants.len()); - for (_handle, override_, span) in module.overrides.drain() { - let key = if let Some(id) = override_.id { - Cow::Owned(id.to_string()) - } else if let Some(ref name) = override_.name { - Cow::Borrowed(name) - } else { - unreachable!(); + let mut global_expression_kind_tracker = crate::proc::ExpressionConstnessTracker::new(); + + let mut override_iter = module.overrides.drain(); + + for (old_h, expr, span) in module.const_expressions.drain() { + let mut expr = match expr { + Expression::Override(h) => { + let c_h = if let Some(new_h) = override_map.get(h.index()) { + *new_h + } else { + let mut new_h = None; + for entry in override_iter.by_ref() { + let stop = entry.0 == h; + new_h = Some(process_override( + entry, + pipeline_constants, + &mut module, + &mut override_map, + &adjusted_const_expressions, + &mut adjusted_constant_initializers, + &mut global_expression_kind_tracker, + )?); + if stop { + break; + } + } + new_h.unwrap() + }; + Expression::Constant(c_h) + } + Expression::Constant(c_h) => { + adjusted_constant_initializers.insert(c_h); + module.constants[c_h].init = adjusted_const_expressions[c_h.index()]; + expr + } + expr => expr, }; - let init = if let Some(value) = pipeline_constants.get::(&key) { - let literal = match module.types[override_.ty].inner { - TypeInner::Scalar(scalar) => map_value_to_literal(*value, scalar)?, - _ => unreachable!(), - }; - module - .const_expressions - .append(Expression::Literal(literal), Span::UNDEFINED) - } else if let Some(init) = override_.init { - init - } else { - return Err(PipelineConstantError::MissingValue(key.to_string())); - }; - let constant = Constant { - name: override_.name, - ty: override_.ty, - init, - }; - module.constants.append(constant, span); + let mut evaluator = ConstantEvaluator::for_wgsl_module( + &mut module, + &mut global_expression_kind_tracker, + false, + ); + adjust_expr(&adjusted_const_expressions, &mut expr); + let h = evaluator.try_eval_and_append(expr, span)?; + debug_assert_eq!(old_h.index(), adjusted_const_expressions.len()); + adjusted_const_expressions.push(h); } - Ok(Cow::Owned(module)) + for entry in override_iter { + process_override( + entry, + pipeline_constants, + &mut module, + &mut override_map, + &adjusted_const_expressions, + &mut adjusted_constant_initializers, + &mut global_expression_kind_tracker, + )?; + } + + for (_, c) in module + .constants + .iter_mut() + .filter(|&(c_h, _)| !adjusted_constant_initializers.contains(&c_h)) + { + c.init = adjusted_const_expressions[c.init.index()]; + } + + for (_, v) in module.global_variables.iter_mut() { + if let Some(ref mut init) = v.init { + *init = adjusted_const_expressions[init.index()]; + } + } + + let mut validator = Validator::new(ValidationFlags::all(), Capabilities::all()); + let module_info = validator.validate(&module)?; + + Ok((Cow::Owned(module), Cow::Owned(module_info))) +} + +fn process_override( + (old_h, override_, span): (Handle, Override, Span), + pipeline_constants: &PipelineConstants, + module: &mut Module, + override_map: &mut Vec>, + adjusted_const_expressions: &[Handle], + adjusted_constant_initializers: &mut HashSet>, + global_expression_kind_tracker: &mut crate::proc::ExpressionConstnessTracker, +) -> Result, PipelineConstantError> { + let key = if let Some(id) = override_.id { + Cow::Owned(id.to_string()) + } else if let Some(ref name) = override_.name { + Cow::Borrowed(name) + } else { + unreachable!(); + }; + let init = if let Some(value) = pipeline_constants.get::(&key) { + let literal = match module.types[override_.ty].inner { + TypeInner::Scalar(scalar) => map_value_to_literal(*value, scalar)?, + _ => unreachable!(), + }; + let expr = module + .const_expressions + .append(Expression::Literal(literal), Span::UNDEFINED); + global_expression_kind_tracker.insert(expr, crate::proc::ExpressionKind::Const); + expr + } else if let Some(init) = override_.init { + adjusted_const_expressions[init.index()] + } else { + return Err(PipelineConstantError::MissingValue(key.to_string())); + }; + let constant = Constant { + name: override_.name, + ty: override_.ty, + init, + }; + let h = module.constants.append(constant, span); + debug_assert_eq!(old_h.index(), override_map.len()); + override_map.push(h); + adjusted_constant_initializers.insert(h); + Ok(h) +} + +fn adjust_expr(new_pos: &[Handle], expr: &mut Expression) { + let adjust = |expr: &mut Handle| { + *expr = new_pos[expr.index()]; + }; + match *expr { + Expression::Compose { + ref mut components, .. + } => { + for c in components.iter_mut() { + adjust(c); + } + } + Expression::Access { + ref mut base, + ref mut index, + } => { + adjust(base); + adjust(index); + } + Expression::AccessIndex { ref mut base, .. } => { + adjust(base); + } + Expression::Splat { ref mut value, .. } => { + adjust(value); + } + Expression::Swizzle { ref mut vector, .. } => { + adjust(vector); + } + Expression::Load { ref mut pointer } => { + adjust(pointer); + } + Expression::ImageSample { + ref mut image, + ref mut sampler, + ref mut coordinate, + ref mut array_index, + ref mut offset, + ref mut level, + ref mut depth_ref, + .. + } => { + adjust(image); + adjust(sampler); + adjust(coordinate); + if let Some(e) = array_index.as_mut() { + adjust(e); + } + if let Some(e) = offset.as_mut() { + adjust(e); + } + match *level { + crate::SampleLevel::Exact(ref mut expr) + | crate::SampleLevel::Bias(ref mut expr) => { + adjust(expr); + } + crate::SampleLevel::Gradient { + ref mut x, + ref mut y, + } => { + adjust(x); + adjust(y); + } + _ => {} + } + if let Some(e) = depth_ref.as_mut() { + adjust(e); + } + } + Expression::ImageLoad { + ref mut image, + ref mut coordinate, + ref mut array_index, + ref mut sample, + ref mut level, + } => { + adjust(image); + adjust(coordinate); + if let Some(e) = array_index.as_mut() { + adjust(e); + } + if let Some(e) = sample.as_mut() { + adjust(e); + } + if let Some(e) = level.as_mut() { + adjust(e); + } + } + Expression::ImageQuery { + ref mut image, + ref mut query, + } => { + adjust(image); + match *query { + crate::ImageQuery::Size { ref mut level } => { + if let Some(e) = level.as_mut() { + adjust(e); + } + } + _ => {} + } + } + Expression::Unary { ref mut expr, .. } => { + adjust(expr); + } + Expression::Binary { + ref mut left, + ref mut right, + .. + } => { + adjust(left); + adjust(right); + } + Expression::Select { + ref mut condition, + ref mut accept, + ref mut reject, + } => { + adjust(condition); + adjust(accept); + adjust(reject); + } + Expression::Derivative { ref mut expr, .. } => { + adjust(expr); + } + Expression::Relational { + ref mut argument, .. + } => { + adjust(argument); + } + Expression::Math { + ref mut arg, + ref mut arg1, + ref mut arg2, + ref mut arg3, + .. + } => { + adjust(arg); + if let Some(e) = arg1.as_mut() { + adjust(e); + } + if let Some(e) = arg2.as_mut() { + adjust(e); + } + if let Some(e) = arg3.as_mut() { + adjust(e); + } + } + Expression::As { ref mut expr, .. } => { + adjust(expr); + } + Expression::ArrayLength(ref mut expr) => { + adjust(expr); + } + Expression::RayQueryGetIntersection { ref mut query, .. } => { + adjust(query); + } + Expression::Literal(_) + | Expression::FunctionArgument(_) + | Expression::GlobalVariable(_) + | Expression::LocalVariable(_) + | Expression::CallResult(_) + | Expression::RayQueryProceedResult + | Expression::Constant(_) + | Expression::Override(_) + | Expression::ZeroValue(_) + | Expression::AtomicResult { .. } + | Expression::WorkGroupUniformLoadResult { .. } => {} + } } fn map_value_to_literal(value: f64, scalar: Scalar) -> Result { diff --git a/naga/src/back/spv/mod.rs b/naga/src/back/spv/mod.rs index 3c0332d59d..f1bbaecce1 100644 --- a/naga/src/back/spv/mod.rs +++ b/naga/src/back/spv/mod.rs @@ -71,7 +71,7 @@ pub enum Error { #[error("module is not validated properly: {0}")] Validation(&'static str), #[error(transparent)] - PipelineConstant(#[from] crate::back::pipeline_constants::PipelineConstantError), + PipelineConstant(#[from] Box), } #[derive(Default)] @@ -529,6 +529,42 @@ struct FunctionArgument { handle_id: Word, } +/// Tracks the expressions for which the backend emits the following instructions: +/// - OpConstantTrue +/// - OpConstantFalse +/// - OpConstant +/// - OpConstantComposite +/// - OpConstantNull +struct ExpressionConstnessTracker { + inner: bit_set::BitSet, +} + +impl ExpressionConstnessTracker { + fn from_arena(arena: &crate::Arena) -> Self { + let mut inner = bit_set::BitSet::new(); + for (handle, expr) in arena.iter() { + let insert = match *expr { + crate::Expression::Literal(_) + | crate::Expression::ZeroValue(_) + | crate::Expression::Constant(_) => true, + crate::Expression::Compose { ref components, .. } => { + components.iter().all(|h| inner.contains(h.index())) + } + crate::Expression::Splat { value, .. } => inner.contains(value.index()), + _ => false, + }; + if insert { + inner.insert(handle.index()); + } + } + Self { inner } + } + + fn is_const(&self, value: Handle) -> bool { + self.inner.contains(value.index()) + } +} + /// General information needed to emit SPIR-V for Naga statements. struct BlockContext<'w> { /// The writer handling the module to which this code belongs. @@ -554,7 +590,7 @@ struct BlockContext<'w> { temp_list: Vec, /// Tracks the constness of `Expression`s residing in `self.ir_function.expressions` - expression_constness: crate::proc::ExpressionConstnessTracker, + expression_constness: ExpressionConstnessTracker, } impl BlockContext<'_> { diff --git a/naga/src/back/spv/writer.rs b/naga/src/back/spv/writer.rs index 975aa625d0..868fad7fa2 100644 --- a/naga/src/back/spv/writer.rs +++ b/naga/src/back/spv/writer.rs @@ -615,7 +615,7 @@ impl Writer { // Steal the Writer's temp list for a bit. temp_list: std::mem::take(&mut self.temp_list), writer: self, - expression_constness: crate::proc::ExpressionConstnessTracker::from_arena( + expression_constness: super::ExpressionConstnessTracker::from_arena( &ir_function.expressions, ), }; @@ -2029,15 +2029,21 @@ impl Writer { debug_info: &Option, words: &mut Vec, ) -> Result<(), Error> { - let ir_module = if let Some(pipeline_options) = pipeline_options { + let (ir_module, info) = if let Some(pipeline_options) = pipeline_options { crate::back::pipeline_constants::process_overrides( ir_module, + info, &pipeline_options.constants, - )? + ) + .map_err(Box::new)? } else { - std::borrow::Cow::Borrowed(ir_module) + ( + std::borrow::Cow::Borrowed(ir_module), + std::borrow::Cow::Borrowed(info), + ) }; let ir_module = ir_module.as_ref(); + let info = info.as_ref(); self.reset(); diff --git a/naga/src/front/glsl/context.rs b/naga/src/front/glsl/context.rs index a3b4e0edde..0c370cd5e5 100644 --- a/naga/src/front/glsl/context.rs +++ b/naga/src/front/glsl/context.rs @@ -77,12 +77,19 @@ pub struct Context<'a> { pub body: Block, pub module: &'a mut crate::Module, pub is_const: bool, - /// Tracks the constness of `Expression`s residing in `self.expressions` - pub expression_constness: crate::proc::ExpressionConstnessTracker, + /// Tracks the expression kind of `Expression`s residing in `self.expressions` + pub local_expression_kind_tracker: crate::proc::ExpressionConstnessTracker, + /// Tracks the expression kind of `Expression`s residing in `self.module.const_expressions` + pub global_expression_kind_tracker: &'a mut crate::proc::ExpressionConstnessTracker, } impl<'a> Context<'a> { - pub fn new(frontend: &Frontend, module: &'a mut crate::Module, is_const: bool) -> Result { + pub fn new( + frontend: &Frontend, + module: &'a mut crate::Module, + is_const: bool, + global_expression_kind_tracker: &'a mut crate::proc::ExpressionConstnessTracker, + ) -> Result { let mut this = Context { expressions: Arena::new(), locals: Arena::new(), @@ -101,7 +108,8 @@ impl<'a> Context<'a> { body: Block::new(), module, is_const: false, - expression_constness: crate::proc::ExpressionConstnessTracker::new(), + local_expression_kind_tracker: crate::proc::ExpressionConstnessTracker::new(), + global_expression_kind_tracker, }; this.emit_start(); @@ -249,12 +257,15 @@ impl<'a> Context<'a> { pub fn add_expression(&mut self, expr: Expression, meta: Span) -> Result> { let mut eval = if self.is_const { - crate::proc::ConstantEvaluator::for_glsl_module(self.module) + crate::proc::ConstantEvaluator::for_glsl_module( + self.module, + self.global_expression_kind_tracker, + ) } else { crate::proc::ConstantEvaluator::for_glsl_function( self.module, &mut self.expressions, - &mut self.expression_constness, + &mut self.local_expression_kind_tracker, &mut self.emitter, &mut self.body, ) diff --git a/naga/src/front/glsl/functions.rs b/naga/src/front/glsl/functions.rs index 01846eb814..fa1bbef56b 100644 --- a/naga/src/front/glsl/functions.rs +++ b/naga/src/front/glsl/functions.rs @@ -1236,6 +1236,8 @@ impl Frontend { let pointer = ctx .expressions .append(Expression::GlobalVariable(arg.handle), Default::default()); + ctx.local_expression_kind_tracker + .insert(pointer, crate::proc::ExpressionKind::Runtime); let ty = ctx.module.global_variables[arg.handle].ty; @@ -1256,6 +1258,8 @@ impl Frontend { let value = ctx .expressions .append(Expression::FunctionArgument(idx), Default::default()); + ctx.local_expression_kind_tracker + .insert(value, crate::proc::ExpressionKind::Runtime); ctx.body .push(Statement::Store { pointer, value }, Default::default()); }, @@ -1285,6 +1289,8 @@ impl Frontend { let pointer = ctx .expressions .append(Expression::GlobalVariable(arg.handle), Default::default()); + ctx.local_expression_kind_tracker + .insert(pointer, crate::proc::ExpressionKind::Runtime); let ty = ctx.module.global_variables[arg.handle].ty; @@ -1307,6 +1313,8 @@ impl Frontend { let load = ctx .expressions .append(Expression::Load { pointer }, Default::default()); + ctx.local_expression_kind_tracker + .insert(load, crate::proc::ExpressionKind::Runtime); ctx.body.push( Statement::Emit(ctx.expressions.range_from(len)), Default::default(), @@ -1329,6 +1337,8 @@ impl Frontend { let res = ctx .expressions .append(Expression::Compose { ty, components }, Default::default()); + ctx.local_expression_kind_tracker + .insert(res, crate::proc::ExpressionKind::Runtime); ctx.body.push( Statement::Emit(ctx.expressions.range_from(len)), Default::default(), diff --git a/naga/src/front/glsl/parser.rs b/naga/src/front/glsl/parser.rs index 851d2e1d79..d4eb39b39b 100644 --- a/naga/src/front/glsl/parser.rs +++ b/naga/src/front/glsl/parser.rs @@ -164,9 +164,15 @@ impl<'source> ParsingContext<'source> { pub fn parse(&mut self, frontend: &mut Frontend) -> Result { let mut module = Module::default(); + let mut global_expression_kind_tracker = crate::proc::ExpressionConstnessTracker::new(); // Body and expression arena for global initialization - let mut ctx = Context::new(frontend, &mut module, false)?; + let mut ctx = Context::new( + frontend, + &mut module, + false, + &mut global_expression_kind_tracker, + )?; while self.peek(frontend).is_some() { self.parse_external_declaration(frontend, &mut ctx)?; @@ -196,7 +202,11 @@ impl<'source> ParsingContext<'source> { frontend: &mut Frontend, ctx: &mut Context, ) -> Result<(u32, Span)> { - let (const_expr, meta) = self.parse_constant_expression(frontend, ctx.module)?; + let (const_expr, meta) = self.parse_constant_expression( + frontend, + ctx.module, + ctx.global_expression_kind_tracker, + )?; let res = ctx.module.to_ctx().eval_expr_to_u32(const_expr); @@ -219,8 +229,9 @@ impl<'source> ParsingContext<'source> { &mut self, frontend: &mut Frontend, module: &mut Module, + global_expression_kind_tracker: &mut crate::proc::ExpressionConstnessTracker, ) -> Result<(Handle, Span)> { - let mut ctx = Context::new(frontend, module, true)?; + let mut ctx = Context::new(frontend, module, true, global_expression_kind_tracker)?; let mut stmt_ctx = ctx.stmt_ctx(); let expr = self.parse_conditional(frontend, &mut ctx, &mut stmt_ctx, None)?; diff --git a/naga/src/front/glsl/parser/declarations.rs b/naga/src/front/glsl/parser/declarations.rs index f5e38fb016..2d253a378d 100644 --- a/naga/src/front/glsl/parser/declarations.rs +++ b/naga/src/front/glsl/parser/declarations.rs @@ -251,7 +251,7 @@ impl<'source> ParsingContext<'source> { init.and_then(|expr| ctx.ctx.lift_up_const_expression(expr).ok()); late_initializer = None; } else if let Some(init) = init { - if ctx.is_inside_loop || !ctx.ctx.expression_constness.is_const(init) { + if ctx.is_inside_loop || !ctx.ctx.local_expression_kind_tracker.is_const(init) { decl_initializer = None; late_initializer = Some(init); } else { @@ -326,7 +326,12 @@ impl<'source> ParsingContext<'source> { let result = ty.map(|ty| FunctionResult { ty, binding: None }); - let mut context = Context::new(frontend, ctx.module, false)?; + let mut context = Context::new( + frontend, + ctx.module, + false, + ctx.global_expression_kind_tracker, + )?; self.parse_function_args(frontend, &mut context)?; diff --git a/naga/src/front/glsl/parser/functions.rs b/naga/src/front/glsl/parser/functions.rs index d428d74761..6d3b9d7ba4 100644 --- a/naga/src/front/glsl/parser/functions.rs +++ b/naga/src/front/glsl/parser/functions.rs @@ -192,8 +192,11 @@ impl<'source> ParsingContext<'source> { TokenValue::Case => { self.bump(frontend)?; - let (const_expr, meta) = - self.parse_constant_expression(frontend, ctx.module)?; + let (const_expr, meta) = self.parse_constant_expression( + frontend, + ctx.module, + ctx.global_expression_kind_tracker, + )?; match ctx.module.const_expressions[const_expr] { Expression::Literal(Literal::I32(value)) => match uint { diff --git a/naga/src/front/glsl/types.rs b/naga/src/front/glsl/types.rs index e87d76fffc..8a04b23839 100644 --- a/naga/src/front/glsl/types.rs +++ b/naga/src/front/glsl/types.rs @@ -330,7 +330,7 @@ impl Context<'_> { expr: Handle, ) -> Result> { let meta = self.expressions.get_span(expr); - Ok(match self.expressions[expr] { + let h = match self.expressions[expr] { ref expr @ (Expression::Literal(_) | Expression::Constant(_) | Expression::ZeroValue(_)) => self.module.const_expressions.append(expr.clone(), meta), @@ -355,6 +355,9 @@ impl Context<'_> { meta, }) } - }) + }; + self.global_expression_kind_tracker + .insert(h, crate::proc::ExpressionKind::Const); + Ok(h) } } diff --git a/naga/src/front/wgsl/lower/mod.rs b/naga/src/front/wgsl/lower/mod.rs index 29a87751ca..662e318f8b 100644 --- a/naga/src/front/wgsl/lower/mod.rs +++ b/naga/src/front/wgsl/lower/mod.rs @@ -86,6 +86,8 @@ pub struct GlobalContext<'source, 'temp, 'out> { module: &'out mut crate::Module, const_typifier: &'temp mut Typifier, + + global_expression_kind_tracker: &'temp mut crate::proc::ExpressionConstnessTracker, } impl<'source> GlobalContext<'source, '_, '_> { @@ -97,6 +99,19 @@ impl<'source> GlobalContext<'source, '_, '_> { module: self.module, const_typifier: self.const_typifier, expr_type: ExpressionContextType::Constant, + global_expression_kind_tracker: self.global_expression_kind_tracker, + } + } + + fn as_override(&mut self) -> ExpressionContext<'source, '_, '_> { + ExpressionContext { + ast_expressions: self.ast_expressions, + globals: self.globals, + types: self.types, + module: self.module, + const_typifier: self.const_typifier, + expr_type: ExpressionContextType::Override, + global_expression_kind_tracker: self.global_expression_kind_tracker, } } @@ -165,6 +180,7 @@ pub struct StatementContext<'source, 'temp, 'out> { /// we should consider them to be const. See the use of `force_non_const` in /// the code for lowering `let` bindings. expression_constness: &'temp mut crate::proc::ExpressionConstnessTracker, + global_expression_kind_tracker: &'temp mut crate::proc::ExpressionConstnessTracker, } impl<'a, 'temp> StatementContext<'a, 'temp, '_> { @@ -181,6 +197,7 @@ impl<'a, 'temp> StatementContext<'a, 'temp, '_> { types: self.types, ast_expressions: self.ast_expressions, const_typifier: self.const_typifier, + global_expression_kind_tracker: self.global_expression_kind_tracker, module: self.module, expr_type: ExpressionContextType::Runtime(RuntimeExpressionContext { local_table: self.local_table, @@ -200,6 +217,7 @@ impl<'a, 'temp> StatementContext<'a, 'temp, '_> { types: self.types, module: self.module, const_typifier: self.const_typifier, + global_expression_kind_tracker: self.global_expression_kind_tracker, } } @@ -253,6 +271,14 @@ pub enum ExpressionContextType<'temp, 'out> { /// available in the [`ExpressionContext`], so this variant /// carries no further information. Constant, + + /// We are lowering to an override expression, to be included in the module's + /// constant expression arena. + /// + /// Everything override expressions are allowed to refer to is + /// available in the [`ExpressionContext`], so this variant + /// carries no further information. + Override, } /// State for lowering an [`ast::Expression`] to Naga IR. @@ -311,6 +337,7 @@ pub struct ExpressionContext<'source, 'temp, 'out> { /// /// [`module::const_expressions`]: crate::Module::const_expressions const_typifier: &'temp mut Typifier, + global_expression_kind_tracker: &'temp mut crate::proc::ExpressionConstnessTracker, /// Whether we are lowering a constant expression or a general /// runtime expression, and the data needed in each case. @@ -326,6 +353,7 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { const_typifier: self.const_typifier, module: self.module, expr_type: ExpressionContextType::Constant, + global_expression_kind_tracker: self.global_expression_kind_tracker, } } @@ -336,6 +364,7 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { types: self.types, module: self.module, const_typifier: self.const_typifier, + global_expression_kind_tracker: self.global_expression_kind_tracker, } } @@ -348,7 +377,16 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { rctx.emitter, rctx.block, ), - ExpressionContextType::Constant => ConstantEvaluator::for_wgsl_module(self.module), + ExpressionContextType::Constant => ConstantEvaluator::for_wgsl_module( + self.module, + self.global_expression_kind_tracker, + false, + ), + ExpressionContextType::Override => ConstantEvaluator::for_wgsl_module( + self.module, + self.global_expression_kind_tracker, + true, + ), } } @@ -375,20 +413,25 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { .ok() } ExpressionContextType::Constant => self.module.to_ctx().eval_expr_to_u32(handle).ok(), + ExpressionContextType::Override => None, } } fn get_expression_span(&self, handle: Handle) -> Span { match self.expr_type { ExpressionContextType::Runtime(ref ctx) => ctx.function.expressions.get_span(handle), - ExpressionContextType::Constant => self.module.const_expressions.get_span(handle), + ExpressionContextType::Constant | ExpressionContextType::Override => { + self.module.const_expressions.get_span(handle) + } } } fn typifier(&self) -> &Typifier { match self.expr_type { ExpressionContextType::Runtime(ref ctx) => ctx.typifier, - ExpressionContextType::Constant => self.const_typifier, + ExpressionContextType::Constant | ExpressionContextType::Override => { + self.const_typifier + } } } @@ -398,7 +441,9 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { ) -> Result<&mut RuntimeExpressionContext<'temp, 'out>, Error<'source>> { match self.expr_type { ExpressionContextType::Runtime(ref mut ctx) => Ok(ctx), - ExpressionContextType::Constant => Err(Error::UnexpectedOperationInConstContext(span)), + ExpressionContextType::Constant | ExpressionContextType::Override => { + Err(Error::UnexpectedOperationInConstContext(span)) + } } } @@ -435,7 +480,7 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { } // This means a `gather` operation appeared in a constant expression. // This error refers to the `gather` itself, not its "component" argument. - ExpressionContextType::Constant => { + ExpressionContextType::Constant | ExpressionContextType::Override => { Err(Error::UnexpectedOperationInConstContext(gather_span)) } } @@ -461,7 +506,9 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { // to also borrow self.module.types mutably below. let typifier = match self.expr_type { ExpressionContextType::Runtime(ref ctx) => ctx.typifier, - ExpressionContextType::Constant => &*self.const_typifier, + ExpressionContextType::Constant | ExpressionContextType::Override => { + &*self.const_typifier + } }; Ok(typifier.register_type(handle, &mut self.module.types)) } @@ -504,7 +551,7 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { typifier = &mut *ctx.typifier; expressions = &ctx.function.expressions; } - ExpressionContextType::Constant => { + ExpressionContextType::Constant | ExpressionContextType::Override => { resolve_ctx = ResolveContext::with_locals(self.module, &empty_arena, &[]); typifier = self.const_typifier; expressions = &self.module.const_expressions; @@ -600,14 +647,14 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { rctx.block .extend(rctx.emitter.finish(&rctx.function.expressions)); } - ExpressionContextType::Constant => {} + ExpressionContextType::Constant | ExpressionContextType::Override => {} } let result = self.append_expression(expression, span); match self.expr_type { ExpressionContextType::Runtime(ref mut rctx) => { rctx.emitter.start(&rctx.function.expressions); } - ExpressionContextType::Constant => {} + ExpressionContextType::Constant | ExpressionContextType::Override => {} } result } @@ -852,6 +899,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { types: &tu.types, module: &mut module, const_typifier: &mut Typifier::new(), + global_expression_kind_tracker: &mut crate::proc::ExpressionConstnessTracker::new(), }; for decl_handle in self.index.visit_ordered() { @@ -959,7 +1007,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { ast::GlobalDeclKind::Override(ref o) => { let init = o .init - .map(|init| self.expression(init, &mut ctx.as_const())) + .map(|init| self.expression(init, &mut ctx.as_override())) .transpose()?; let inferred_type = init .map(|init| ctx.as_const().register_type(init)) @@ -1049,6 +1097,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { let mut local_table = FastHashMap::default(); let mut expressions = Arena::new(); let mut named_expressions = FastIndexMap::default(); + let mut local_expression_kind_tracker = crate::proc::ExpressionConstnessTracker::new(); let arguments = f .arguments @@ -1060,6 +1109,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { .append(crate::Expression::FunctionArgument(i as u32), arg.name.span); local_table.insert(arg.handle, Typed::Plain(expr)); named_expressions.insert(expr, (arg.name.name.to_string(), arg.name.span)); + local_expression_kind_tracker.insert(expr, crate::proc::ExpressionKind::Runtime); Ok(crate::FunctionArgument { name: Some(arg.name.name.to_string()), @@ -1102,7 +1152,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { named_expressions: &mut named_expressions, types: ctx.types, module: ctx.module, - expression_constness: &mut crate::proc::ExpressionConstnessTracker::new(), + expression_constness: &mut local_expression_kind_tracker, + global_expression_kind_tracker: ctx.global_expression_kind_tracker, }; let mut body = self.block(&f.body, false, &mut stmt_ctx)?; ensure_block_returns(&mut body); @@ -1518,6 +1569,10 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { .function .expressions .append(crate::Expression::Binary { op, left, right }, stmt.span); + rctx.expression_constness + .insert(left, crate::proc::ExpressionKind::Runtime); + rctx.expression_constness + .insert(value, crate::proc::ExpressionKind::Runtime); block.extend(emitter.finish(&ctx.function.expressions)); crate::Statement::Store { @@ -1611,7 +1666,12 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { LoweredGlobalDecl::Const(handle) => { Typed::Plain(crate::Expression::Constant(handle)) } - _ => { + LoweredGlobalDecl::Override(handle) => { + Typed::Plain(crate::Expression::Override(handle)) + } + LoweredGlobalDecl::Function(_) + | LoweredGlobalDecl::Type(_) + | LoweredGlobalDecl::EntryPoint => { return Err(Error::Unexpected(span, ExpectedToken::Variable)); } }; @@ -1886,9 +1946,13 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { rctx.block .extend(rctx.emitter.finish(&rctx.function.expressions)); let result = has_result.then(|| { - rctx.function + let result = rctx + .function .expressions - .append(crate::Expression::CallResult(function), span) + .append(crate::Expression::CallResult(function), span); + rctx.expression_constness + .insert(result, crate::proc::ExpressionKind::Runtime); + result }); rctx.emitter.start(&rctx.function.expressions); rctx.block.push( diff --git a/naga/src/proc/constant_evaluator.rs b/naga/src/proc/constant_evaluator.rs index a9c873afbc..6318a57c00 100644 --- a/naga/src/proc/constant_evaluator.rs +++ b/naga/src/proc/constant_evaluator.rs @@ -253,9 +253,9 @@ gen_component_wise_extractor! { } #[derive(Debug)] -enum Behavior { - Wgsl, - Glsl, +enum Behavior<'a> { + Wgsl(WgslRestrictions<'a>), + Glsl(GlslRestrictions<'a>), } /// A context for evaluating constant expressions. @@ -278,7 +278,7 @@ enum Behavior { #[derive(Debug)] pub struct ConstantEvaluator<'a> { /// Which language's evaluation rules we should follow. - behavior: Behavior, + behavior: Behavior<'a>, /// The module's type arena. /// @@ -297,65 +297,145 @@ pub struct ConstantEvaluator<'a> { /// The arena to which we are contributing expressions. expressions: &'a mut Arena, - /// When `self.expressions` refers to a function's local expression - /// arena, this needs to be populated - function_local_data: Option>, + /// Tracks the constness of expressions residing in [`Self::expressions`] + expression_kind_tracker: &'a mut ExpressionConstnessTracker, +} + +#[derive(Debug)] +enum WgslRestrictions<'a> { + /// - const-expressions will be evaluated and inserted in the arena + Const, + /// - const-expressions will be evaluated and inserted in the arena + /// - override-expressions will be inserted in the arena + Override, + /// - const-expressions will be evaluated and inserted in the arena + /// - override-expressions will be inserted in the arena + /// - runtime-expressions will be inserted in the arena + Runtime(FunctionLocalData<'a>), +} + +#[derive(Debug)] +enum GlslRestrictions<'a> { + /// - const-expressions will be evaluated and inserted in the arena + Const, + /// - const-expressions will be evaluated and inserted in the arena + /// - override-expressions will be inserted in the arena + /// - runtime-expressions will be inserted in the arena + Runtime(FunctionLocalData<'a>), } #[derive(Debug)] struct FunctionLocalData<'a> { /// Global constant expressions const_expressions: &'a Arena, - /// Tracks the constness of expressions residing in `ConstantEvaluator.expressions` - expression_constness: &'a mut ExpressionConstnessTracker, emitter: &'a mut super::Emitter, block: &'a mut crate::Block, } +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)] +pub enum ExpressionKind { + Const, + Override, + Runtime, +} + #[derive(Debug)] pub struct ExpressionConstnessTracker { - inner: bit_set::BitSet, + inner: Vec, } impl ExpressionConstnessTracker { - pub fn new() -> Self { - Self { - inner: bit_set::BitSet::new(), - } + pub const fn new() -> Self { + Self { inner: Vec::new() } } /// Forces the the expression to not be const pub fn force_non_const(&mut self, value: Handle) { - self.inner.remove(value.index()); + self.inner[value.index()] = ExpressionKind::Runtime; } - fn insert(&mut self, value: Handle) { - self.inner.insert(value.index()); + pub fn insert(&mut self, value: Handle, expr_type: ExpressionKind) { + assert_eq!(self.inner.len(), value.index()); + self.inner.push(expr_type); + } + pub fn is_const(&self, h: Handle) -> bool { + matches!(self.type_of(h), ExpressionKind::Const) } - pub fn is_const(&self, value: Handle) -> bool { - self.inner.contains(value.index()) + pub fn is_const_or_override(&self, h: Handle) -> bool { + matches!( + self.type_of(h), + ExpressionKind::Const | ExpressionKind::Override + ) + } + + fn type_of(&self, value: Handle) -> ExpressionKind { + self.inner[value.index()] } pub fn from_arena(arena: &Arena) -> Self { - let mut tracker = Self::new(); - for (handle, expr) in arena.iter() { - let insert = match *expr { - crate::Expression::Literal(_) - | crate::Expression::ZeroValue(_) - | crate::Expression::Constant(_) => true, - crate::Expression::Compose { ref components, .. } => { - components.iter().all(|h| tracker.is_const(*h)) - } - crate::Expression::Splat { value, .. } => tracker.is_const(value), - _ => false, - }; - if insert { - tracker.insert(handle); - } + let mut tracker = Self { + inner: Vec::with_capacity(arena.len()), + }; + for (_, expr) in arena.iter() { + tracker.inner.push(tracker.type_of_with_expr(expr)); } tracker } + + fn type_of_with_expr(&self, expr: &Expression) -> ExpressionKind { + match *expr { + Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => { + ExpressionKind::Const + } + Expression::Override(_) => ExpressionKind::Override, + Expression::Compose { ref components, .. } => { + let mut expr_type = ExpressionKind::Const; + for component in components { + expr_type = expr_type.max(self.type_of(*component)) + } + expr_type + } + Expression::Splat { value, .. } => self.type_of(value), + Expression::AccessIndex { base, .. } => self.type_of(base), + Expression::Access { base, index } => self.type_of(base).max(self.type_of(index)), + Expression::Swizzle { vector, .. } => self.type_of(vector), + Expression::Unary { expr, .. } => self.type_of(expr), + Expression::Binary { left, right, .. } => self.type_of(left).max(self.type_of(right)), + Expression::Math { + arg, + arg1, + arg2, + arg3, + .. + } => self + .type_of(arg) + .max( + arg1.map(|arg| self.type_of(arg)) + .unwrap_or(ExpressionKind::Const), + ) + .max( + arg2.map(|arg| self.type_of(arg)) + .unwrap_or(ExpressionKind::Const), + ) + .max( + arg3.map(|arg| self.type_of(arg)) + .unwrap_or(ExpressionKind::Const), + ), + Expression::As { expr, .. } => self.type_of(expr), + Expression::Select { + condition, + accept, + reject, + } => self + .type_of(condition) + .max(self.type_of(accept)) + .max(self.type_of(reject)), + Expression::Relational { argument, .. } => self.type_of(argument), + Expression::ArrayLength(expr) => self.type_of(expr), + _ => ExpressionKind::Runtime, + } + } } #[derive(Clone, Debug, thiserror::Error)] @@ -436,6 +516,12 @@ pub enum ConstantEvaluatorError { ShiftedMoreThan32Bits, #[error(transparent)] Literal(#[from] crate::valid::LiteralError), + #[error("Can't use pipeline-overridable constants in const-expressions")] + Override, + #[error("Unexpected runtime-expression")] + RuntimeExpr, + #[error("Unexpected override-expression")] + OverrideExpr, } impl<'a> ConstantEvaluator<'a> { @@ -443,26 +529,49 @@ impl<'a> ConstantEvaluator<'a> { /// constant expression arena. /// /// Report errors according to WGSL's rules for constant evaluation. - pub fn for_wgsl_module(module: &'a mut crate::Module) -> Self { - Self::for_module(Behavior::Wgsl, module) + pub fn for_wgsl_module( + module: &'a mut crate::Module, + global_expression_kind_tracker: &'a mut ExpressionConstnessTracker, + in_override_ctx: bool, + ) -> Self { + Self::for_module( + Behavior::Wgsl(if in_override_ctx { + WgslRestrictions::Override + } else { + WgslRestrictions::Const + }), + module, + global_expression_kind_tracker, + ) } /// Return a [`ConstantEvaluator`] that will add expressions to `module`'s /// constant expression arena. /// /// Report errors according to GLSL's rules for constant evaluation. - pub fn for_glsl_module(module: &'a mut crate::Module) -> Self { - Self::for_module(Behavior::Glsl, module) + pub fn for_glsl_module( + module: &'a mut crate::Module, + global_expression_kind_tracker: &'a mut ExpressionConstnessTracker, + ) -> Self { + Self::for_module( + Behavior::Glsl(GlslRestrictions::Const), + module, + global_expression_kind_tracker, + ) } - fn for_module(behavior: Behavior, module: &'a mut crate::Module) -> Self { + fn for_module( + behavior: Behavior<'a>, + module: &'a mut crate::Module, + global_expression_kind_tracker: &'a mut ExpressionConstnessTracker, + ) -> Self { Self { behavior, types: &mut module.types, constants: &module.constants, overrides: &module.overrides, expressions: &mut module.const_expressions, - function_local_data: None, + expression_kind_tracker: global_expression_kind_tracker, } } @@ -473,18 +582,22 @@ impl<'a> ConstantEvaluator<'a> { pub fn for_wgsl_function( module: &'a mut crate::Module, expressions: &'a mut Arena, - expression_constness: &'a mut ExpressionConstnessTracker, + local_expression_kind_tracker: &'a mut ExpressionConstnessTracker, emitter: &'a mut super::Emitter, block: &'a mut crate::Block, ) -> Self { - Self::for_function( - Behavior::Wgsl, - module, + Self { + behavior: Behavior::Wgsl(WgslRestrictions::Runtime(FunctionLocalData { + const_expressions: &module.const_expressions, + emitter, + block, + })), + types: &mut module.types, + constants: &module.constants, + overrides: &module.overrides, expressions, - expression_constness, - emitter, - block, - ) + expression_kind_tracker: local_expression_kind_tracker, + } } /// Return a [`ConstantEvaluator`] that will add expressions to `function`'s @@ -494,40 +607,21 @@ impl<'a> ConstantEvaluator<'a> { pub fn for_glsl_function( module: &'a mut crate::Module, expressions: &'a mut Arena, - expression_constness: &'a mut ExpressionConstnessTracker, - emitter: &'a mut super::Emitter, - block: &'a mut crate::Block, - ) -> Self { - Self::for_function( - Behavior::Glsl, - module, - expressions, - expression_constness, - emitter, - block, - ) - } - - fn for_function( - behavior: Behavior, - module: &'a mut crate::Module, - expressions: &'a mut Arena, - expression_constness: &'a mut ExpressionConstnessTracker, + local_expression_kind_tracker: &'a mut ExpressionConstnessTracker, emitter: &'a mut super::Emitter, block: &'a mut crate::Block, ) -> Self { Self { - behavior, + behavior: Behavior::Glsl(GlslRestrictions::Runtime(FunctionLocalData { + const_expressions: &module.const_expressions, + emitter, + block, + })), types: &mut module.types, constants: &module.constants, overrides: &module.overrides, expressions, - function_local_data: Some(FunctionLocalData { - const_expressions: &module.const_expressions, - expression_constness, - emitter, - block, - }), + expression_kind_tracker: local_expression_kind_tracker, } } @@ -536,19 +630,17 @@ impl<'a> ConstantEvaluator<'a> { types: self.types, constants: self.constants, overrides: self.overrides, - const_expressions: match self.function_local_data { - Some(ref data) => data.const_expressions, + const_expressions: match self.function_local_data() { + Some(data) => data.const_expressions, None => self.expressions, }, } } fn check(&self, expr: Handle) -> Result<(), ConstantEvaluatorError> { - if let Some(ref function_local_data) = self.function_local_data { - if !function_local_data.expression_constness.is_const(expr) { - log::debug!("check: SubexpressionsAreNotConstant"); - return Err(ConstantEvaluatorError::SubexpressionsAreNotConstant); - } + if !self.expression_kind_tracker.is_const(expr) { + log::debug!("check: SubexpressionsAreNotConstant"); + return Err(ConstantEvaluatorError::SubexpressionsAreNotConstant); } Ok(()) } @@ -561,7 +653,7 @@ impl<'a> ConstantEvaluator<'a> { Expression::Constant(c) => { // Are we working in a function's expression arena, or the // module's constant expression arena? - if let Some(ref function_local_data) = self.function_local_data { + if let Some(function_local_data) = self.function_local_data() { // Deep-copy the constant's value into our arena. self.copy_from( self.constants[c].init, @@ -607,14 +699,56 @@ impl<'a> ConstantEvaluator<'a> { expr: Expression, span: Span, ) -> Result, ConstantEvaluatorError> { - let res = self.try_eval_and_append_impl(&expr, span); - if self.function_local_data.is_some() { - match res { - Ok(h) => Ok(h), - Err(_) => Ok(self.append_expr(expr, span, false)), + match ( + &self.behavior, + self.expression_kind_tracker.type_of_with_expr(&expr), + ) { + // avoid errors on unimplemented functionality if possible + ( + &Behavior::Wgsl(WgslRestrictions::Runtime(_)) + | &Behavior::Glsl(GlslRestrictions::Runtime(_)), + ExpressionKind::Const, + ) => match self.try_eval_and_append_impl(&expr, span) { + Err( + ConstantEvaluatorError::NotImplemented(_) + | ConstantEvaluatorError::InvalidBinaryOpArgs, + ) => Ok(self.append_expr(expr, span, ExpressionKind::Runtime)), + res => res, + }, + (_, ExpressionKind::Const) => self.try_eval_and_append_impl(&expr, span), + (&Behavior::Wgsl(WgslRestrictions::Const), ExpressionKind::Override) => { + Err(ConstantEvaluatorError::OverrideExpr) } - } else { - res + ( + &Behavior::Wgsl(WgslRestrictions::Override | WgslRestrictions::Runtime(_)), + ExpressionKind::Override, + ) => Ok(self.append_expr(expr, span, ExpressionKind::Override)), + (&Behavior::Glsl(_), ExpressionKind::Override) => unreachable!(), + ( + &Behavior::Wgsl(WgslRestrictions::Runtime(_)) + | &Behavior::Glsl(GlslRestrictions::Runtime(_)), + ExpressionKind::Runtime, + ) => Ok(self.append_expr(expr, span, ExpressionKind::Runtime)), + (_, ExpressionKind::Runtime) => Err(ConstantEvaluatorError::RuntimeExpr), + } + } + + /// Is the [`Self::expressions`] arena the global module expression arena? + const fn is_global_arena(&self) -> bool { + matches!( + self.behavior, + Behavior::Wgsl(WgslRestrictions::Const | WgslRestrictions::Override) + | Behavior::Glsl(GlslRestrictions::Const) + ) + } + + const fn function_local_data(&self) -> Option<&FunctionLocalData<'a>> { + match self.behavior { + Behavior::Wgsl(WgslRestrictions::Runtime(ref function_local_data)) + | Behavior::Glsl(GlslRestrictions::Runtime(ref function_local_data)) => { + Some(function_local_data) + } + _ => None, } } @@ -625,14 +759,12 @@ impl<'a> ConstantEvaluator<'a> { ) -> Result, ConstantEvaluatorError> { log::trace!("try_eval_and_append: {:?}", expr); match *expr { - Expression::Constant(c) if self.function_local_data.is_none() => { + Expression::Constant(c) if self.is_global_arena() => { // "See through" the constant and use its initializer. // This is mainly done to avoid having constants pointing to other constants. Ok(self.constants[c].init) } - Expression::Override(_) => Err(ConstantEvaluatorError::NotImplemented( - "overrides are WIP".into(), - )), + Expression::Override(_) => Err(ConstantEvaluatorError::Override), Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => { self.register_evaluated_expr(expr.clone(), span) } @@ -713,8 +845,8 @@ impl<'a> ConstantEvaluator<'a> { format!("{fun:?} built-in function"), )), Expression::ArrayLength(expr) => match self.behavior { - Behavior::Wgsl => Err(ConstantEvaluatorError::ArrayLength), - Behavior::Glsl => { + Behavior::Wgsl(_) => Err(ConstantEvaluatorError::ArrayLength), + Behavior::Glsl(_) => { let expr = self.check_and_get(expr)?; self.array_length(expr, span) } @@ -1881,34 +2013,35 @@ impl<'a> ConstantEvaluator<'a> { crate::valid::check_literal_value(literal)?; } - Ok(self.append_expr(expr, span, true)) + Ok(self.append_expr(expr, span, ExpressionKind::Const)) } - fn append_expr(&mut self, expr: Expression, span: Span, is_const: bool) -> Handle { - if let Some(FunctionLocalData { - ref mut emitter, - ref mut block, - ref mut expression_constness, - .. - }) = self.function_local_data - { - let is_running = emitter.is_running(); - let needs_pre_emit = expr.needs_pre_emit(); - let h = if is_running && needs_pre_emit { - block.extend(emitter.finish(self.expressions)); - let h = self.expressions.append(expr, span); - emitter.start(self.expressions); - h - } else { - self.expressions.append(expr, span) - }; - if is_const { - expression_constness.insert(h); + fn append_expr( + &mut self, + expr: Expression, + span: Span, + expr_type: ExpressionKind, + ) -> Handle { + let h = match self.behavior { + Behavior::Wgsl(WgslRestrictions::Runtime(ref mut function_local_data)) + | Behavior::Glsl(GlslRestrictions::Runtime(ref mut function_local_data)) => { + let is_running = function_local_data.emitter.is_running(); + let needs_pre_emit = expr.needs_pre_emit(); + if is_running && needs_pre_emit { + function_local_data + .block + .extend(function_local_data.emitter.finish(self.expressions)); + let h = self.expressions.append(expr, span); + function_local_data.emitter.start(self.expressions); + h + } else { + self.expressions.append(expr, span) + } } - h - } else { - self.expressions.append(expr, span) - } + _ => self.expressions.append(expr, span), + }; + self.expression_kind_tracker.insert(h, expr_type); + h } fn resolve_type( @@ -2062,7 +2195,7 @@ mod tests { UniqueArena, VectorSize, }; - use super::{Behavior, ConstantEvaluator}; + use super::{Behavior, ConstantEvaluator, ExpressionConstnessTracker, WgslRestrictions}; #[test] fn unary_op() { @@ -2143,13 +2276,15 @@ mod tests { expr: expr1, }; + let expression_kind_tracker = + &mut ExpressionConstnessTracker::from_arena(&const_expressions); let mut solver = ConstantEvaluator { - behavior: Behavior::Wgsl, + behavior: Behavior::Wgsl(WgslRestrictions::Const), types: &mut types, constants: &constants, overrides: &overrides, expressions: &mut const_expressions, - function_local_data: None, + expression_kind_tracker, }; let res1 = solver @@ -2228,13 +2363,15 @@ mod tests { convert: Some(crate::BOOL_WIDTH), }; + let expression_kind_tracker = + &mut ExpressionConstnessTracker::from_arena(&const_expressions); let mut solver = ConstantEvaluator { - behavior: Behavior::Wgsl, + behavior: Behavior::Wgsl(WgslRestrictions::Const), types: &mut types, constants: &constants, overrides: &overrides, expressions: &mut const_expressions, - function_local_data: None, + expression_kind_tracker, }; let res = solver @@ -2345,13 +2482,15 @@ mod tests { let base = const_expressions.append(Expression::Constant(h), Default::default()); + let expression_kind_tracker = + &mut ExpressionConstnessTracker::from_arena(&const_expressions); let mut solver = ConstantEvaluator { - behavior: Behavior::Wgsl, + behavior: Behavior::Wgsl(WgslRestrictions::Const), types: &mut types, constants: &constants, overrides: &overrides, expressions: &mut const_expressions, - function_local_data: None, + expression_kind_tracker, }; let root1 = Expression::AccessIndex { base, index: 1 }; @@ -2437,13 +2576,15 @@ mod tests { let h_expr = const_expressions.append(Expression::Constant(h), Default::default()); + let expression_kind_tracker = + &mut ExpressionConstnessTracker::from_arena(&const_expressions); let mut solver = ConstantEvaluator { - behavior: Behavior::Wgsl, + behavior: Behavior::Wgsl(WgslRestrictions::Const), types: &mut types, constants: &constants, overrides: &overrides, expressions: &mut const_expressions, - function_local_data: None, + expression_kind_tracker, }; let solved_compose = solver @@ -2518,13 +2659,15 @@ mod tests { let h_expr = const_expressions.append(Expression::Constant(h), Default::default()); + let expression_kind_tracker = + &mut ExpressionConstnessTracker::from_arena(&const_expressions); let mut solver = ConstantEvaluator { - behavior: Behavior::Wgsl, + behavior: Behavior::Wgsl(WgslRestrictions::Const), types: &mut types, constants: &constants, overrides: &overrides, expressions: &mut const_expressions, - function_local_data: None, + expression_kind_tracker, }; let solved_compose = solver diff --git a/naga/src/proc/mod.rs b/naga/src/proc/mod.rs index 6dc677ff23..2db956ee0e 100644 --- a/naga/src/proc/mod.rs +++ b/naga/src/proc/mod.rs @@ -11,7 +11,7 @@ mod terminator; mod typifier; pub use constant_evaluator::{ - ConstantEvaluator, ConstantEvaluatorError, ExpressionConstnessTracker, + ConstantEvaluator, ConstantEvaluatorError, ExpressionConstnessTracker, ExpressionKind, }; pub use emitter::Emitter; pub use index::{BoundsCheckPolicies, BoundsCheckPolicy, IndexableLength, IndexableLengthError}; diff --git a/naga/src/valid/analyzer.rs b/naga/src/valid/analyzer.rs index 84f57f6c8a..fbb4461e38 100644 --- a/naga/src/valid/analyzer.rs +++ b/naga/src/valid/analyzer.rs @@ -226,7 +226,7 @@ struct Sampling { sampler: GlobalOrArgument, } -#[derive(Debug)] +#[derive(Debug, Clone)] #[cfg_attr(feature = "serialize", derive(serde::Serialize))] #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] pub struct FunctionInfo { diff --git a/naga/src/valid/expression.rs b/naga/src/valid/expression.rs index 7b259d69f9..79180a0711 100644 --- a/naga/src/valid/expression.rs +++ b/naga/src/valid/expression.rs @@ -90,6 +90,8 @@ pub enum ExpressionError { sampler: bool, has_ref: bool, }, + #[error("Sample offset must be a const-expression")] + InvalidSampleOffsetExprType, #[error("Sample offset constant {1:?} doesn't match the image dimension {0:?}")] InvalidSampleOffset(crate::ImageDimension, Handle), #[error("Depth reference {0:?} is not a scalar float")] @@ -129,9 +131,10 @@ pub enum ExpressionError { } #[derive(Clone, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] pub enum ConstExpressionError { - #[error("The expression is not a constant expression")] - NonConst, + #[error("The expression is not a constant or override expression")] + NonConstOrOverride, #[error(transparent)] Compose(#[from] super::ComposeError), #[error("Splatting {0:?} can't be done")] @@ -184,9 +187,14 @@ impl super::Validator { handle: Handle, gctx: crate::proc::GlobalCtx, mod_info: &ModuleInfo, + global_expr_kind: &crate::proc::ExpressionConstnessTracker, ) -> Result<(), ConstExpressionError> { use crate::Expression as E; + if !global_expr_kind.is_const_or_override(handle) { + return Err(super::ConstExpressionError::NonConstOrOverride); + } + match gctx.const_expressions[handle] { E::Literal(literal) => { self.validate_literal(literal)?; @@ -203,12 +211,14 @@ impl super::Validator { crate::TypeInner::Scalar { .. } => {} _ => return Err(super::ConstExpressionError::InvalidSplatType(value)), }, - _ => return Err(super::ConstExpressionError::NonConst), + // the constant evaluator will report errors about override-expressions + _ => {} } Ok(()) } + #[allow(clippy::too_many_arguments)] pub(super) fn validate_expression( &self, root: Handle, @@ -217,6 +227,7 @@ impl super::Validator { module: &crate::Module, info: &FunctionInfo, mod_info: &ModuleInfo, + global_expr_kind: &crate::proc::ExpressionConstnessTracker, ) -> Result { use crate::{Expression as E, Scalar as Sc, ScalarKind as Sk, TypeInner as Ti}; @@ -462,6 +473,10 @@ impl super::Validator { // check constant offset if let Some(const_expr) = offset { + if !global_expr_kind.is_const(const_expr) { + return Err(ExpressionError::InvalidSampleOffsetExprType); + } + match *mod_info[const_expr].inner_with(&module.types) { Ti::Scalar(Sc { kind: Sk::Sint, .. }) if num_components == 1 => {} Ti::Vector { diff --git a/naga/src/valid/function.rs b/naga/src/valid/function.rs index f0ca22cbda..dfb7fbc6ee 100644 --- a/naga/src/valid/function.rs +++ b/naga/src/valid/function.rs @@ -927,7 +927,7 @@ impl super::Validator { var: &crate::LocalVariable, gctx: crate::proc::GlobalCtx, fun_info: &FunctionInfo, - expression_constness: &crate::proc::ExpressionConstnessTracker, + local_expr_kind: &crate::proc::ExpressionConstnessTracker, ) -> Result<(), LocalVariableError> { log::debug!("var {:?}", var); let type_info = self @@ -945,7 +945,7 @@ impl super::Validator { return Err(LocalVariableError::InitializerType); } - if !expression_constness.is_const(init) { + if !local_expr_kind.is_const(init) { return Err(LocalVariableError::NonConstInitializer); } } @@ -959,14 +959,14 @@ impl super::Validator { module: &crate::Module, mod_info: &ModuleInfo, entry_point: bool, + global_expr_kind: &crate::proc::ExpressionConstnessTracker, ) -> Result> { let mut info = mod_info.process_function(fun, module, self.flags, self.capabilities)?; - let expression_constness = - crate::proc::ExpressionConstnessTracker::from_arena(&fun.expressions); + let local_expr_kind = crate::proc::ExpressionConstnessTracker::from_arena(&fun.expressions); for (var_handle, var) in fun.local_variables.iter() { - self.validate_local_var(var, module.to_ctx(), &info, &expression_constness) + self.validate_local_var(var, module.to_ctx(), &info, &local_expr_kind) .map_err(|source| { FunctionError::LocalVariable { handle: var_handle, @@ -1032,7 +1032,15 @@ impl super::Validator { self.valid_expression_set.insert(handle.index()); } if self.flags.contains(super::ValidationFlags::EXPRESSIONS) { - match self.validate_expression(handle, expr, fun, module, &info, mod_info) { + match self.validate_expression( + handle, + expr, + fun, + module, + &info, + mod_info, + global_expr_kind, + ) { Ok(stages) => info.available_stages &= stages, Err(source) => { return Err(FunctionError::Expression { handle, source } diff --git a/naga/src/valid/handles.rs b/naga/src/valid/handles.rs index 0643b1c9f5..bcda98b294 100644 --- a/naga/src/valid/handles.rs +++ b/naga/src/valid/handles.rs @@ -592,6 +592,7 @@ impl From for ValidationError { } #[derive(Clone, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] pub enum InvalidHandleError { #[error(transparent)] BadHandle(#[from] BadHandle), @@ -602,6 +603,7 @@ pub enum InvalidHandleError { } #[derive(Clone, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] #[error( "{subject:?} of kind {subject_kind:?} depends on {depends_on:?} of kind {depends_on_kind}, \ which has not been processed yet" diff --git a/naga/src/valid/interface.rs b/naga/src/valid/interface.rs index 84c8b09ddb..945af946bb 100644 --- a/naga/src/valid/interface.rs +++ b/naga/src/valid/interface.rs @@ -10,6 +10,7 @@ use bit_set::BitSet; const MAX_WORKGROUP_SIZE: u32 = 0x4000; #[derive(Clone, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] pub enum GlobalVariableError { #[error("Usage isn't compatible with address space {0:?}")] InvalidUsage(crate::AddressSpace), @@ -30,6 +31,8 @@ pub enum GlobalVariableError { Handle, #[source] Disalignment, ), + #[error("Initializer must be a const-expression")] + InitializerExprType, #[error("Initializer doesn't match the variable type")] InitializerType, #[error("Initializer can't be used with address space {0:?}")] @@ -39,6 +42,7 @@ pub enum GlobalVariableError { } #[derive(Clone, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] pub enum VaryingError { #[error("The type {0:?} does not match the varying")] InvalidType(Handle), @@ -76,6 +80,7 @@ pub enum VaryingError { } #[derive(Clone, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] pub enum EntryPointError { #[error("Multiple conflicting entry points")] Conflict, @@ -395,6 +400,7 @@ impl super::Validator { var: &crate::GlobalVariable, gctx: crate::proc::GlobalCtx, mod_info: &ModuleInfo, + global_expr_kind: &crate::proc::ExpressionConstnessTracker, ) -> Result<(), GlobalVariableError> { use super::TypeFlags; @@ -523,6 +529,10 @@ impl super::Validator { } } + if !global_expr_kind.is_const(init) { + return Err(GlobalVariableError::InitializerExprType); + } + let decl_ty = &gctx.types[var.ty].inner; let init_ty = mod_info[init].inner_with(gctx.types); if !decl_ty.equivalent(init_ty, gctx.types) { @@ -538,6 +548,7 @@ impl super::Validator { ep: &crate::EntryPoint, module: &crate::Module, mod_info: &ModuleInfo, + global_expr_kind: &crate::proc::ExpressionConstnessTracker, ) -> Result> { if ep.early_depth_test.is_some() { let required = Capabilities::EARLY_DEPTH_TEST; @@ -566,7 +577,7 @@ impl super::Validator { } let mut info = self - .validate_function(&ep.function, module, mod_info, true) + .validate_function(&ep.function, module, mod_info, true, global_expr_kind) .map_err(WithSpan::into_other)?; { diff --git a/naga/src/valid/mod.rs b/naga/src/valid/mod.rs index 311279478c..b4b2063775 100644 --- a/naga/src/valid/mod.rs +++ b/naga/src/valid/mod.rs @@ -12,7 +12,7 @@ mod r#type; use crate::{ arena::Handle, - proc::{LayoutError, Layouter, TypeResolution}, + proc::{ExpressionConstnessTracker, LayoutError, Layouter, TypeResolution}, FastHashSet, }; use bit_set::BitSet; @@ -131,7 +131,7 @@ bitflags::bitflags! { } } -#[derive(Debug)] +#[derive(Debug, Clone)] #[cfg_attr(feature = "serialize", derive(serde::Serialize))] #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] pub struct ModuleInfo { @@ -178,7 +178,10 @@ pub struct Validator { } #[derive(Clone, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] pub enum ConstantError { + #[error("Initializer must be a const-expression")] + InitializerExprType, #[error("The type doesn't match the constant")] InvalidType, #[error("The type is not constructible")] @@ -186,11 +189,14 @@ pub enum ConstantError { } #[derive(Clone, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] pub enum OverrideError { #[error("Override name and ID are missing")] MissingNameAndID, #[error("Override ID must be unique")] DuplicateID, + #[error("Initializer must be a const-expression or override-expression")] + InitializerExprType, #[error("The type doesn't match the override")] InvalidType, #[error("The type is not constructible")] @@ -200,6 +206,7 @@ pub enum OverrideError { } #[derive(Clone, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] pub enum ValidationError { #[error(transparent)] InvalidHandle(#[from] InvalidHandleError), @@ -335,6 +342,7 @@ impl Validator { handle: Handle, gctx: crate::proc::GlobalCtx, mod_info: &ModuleInfo, + global_expr_kind: &ExpressionConstnessTracker, ) -> Result<(), ConstantError> { let con = &gctx.constants[handle]; @@ -343,6 +351,10 @@ impl Validator { return Err(ConstantError::NonConstructibleType); } + if !global_expr_kind.is_const(con.init) { + return Err(ConstantError::InitializerExprType); + } + let decl_ty = &gctx.types[con.ty].inner; let init_ty = mod_info[con.init].inner_with(gctx.types); if !decl_ty.equivalent(init_ty, gctx.types) { @@ -455,17 +467,24 @@ impl Validator { } } + let global_expr_kind = ExpressionConstnessTracker::from_arena(&module.const_expressions); + if self.flags.contains(ValidationFlags::CONSTANTS) { for (handle, _) in module.const_expressions.iter() { - self.validate_const_expression(handle, module.to_ctx(), &mod_info) - .map_err(|source| { - ValidationError::ConstExpression { handle, source } - .with_span_handle(handle, &module.const_expressions) - })? + self.validate_const_expression( + handle, + module.to_ctx(), + &mod_info, + &global_expr_kind, + ) + .map_err(|source| { + ValidationError::ConstExpression { handle, source } + .with_span_handle(handle, &module.const_expressions) + })? } for (handle, constant) in module.constants.iter() { - self.validate_constant(handle, module.to_ctx(), &mod_info) + self.validate_constant(handle, module.to_ctx(), &mod_info, &global_expr_kind) .map_err(|source| { ValidationError::Constant { handle, @@ -490,7 +509,7 @@ impl Validator { } for (var_handle, var) in module.global_variables.iter() { - self.validate_global_var(var, module.to_ctx(), &mod_info) + self.validate_global_var(var, module.to_ctx(), &mod_info, &global_expr_kind) .map_err(|source| { ValidationError::GlobalVariable { handle: var_handle, @@ -502,7 +521,7 @@ impl Validator { } for (handle, fun) in module.functions.iter() { - match self.validate_function(fun, module, &mod_info, false) { + match self.validate_function(fun, module, &mod_info, false, &global_expr_kind) { Ok(info) => mod_info.functions.push(info), Err(error) => { return Err(error.and_then(|source| { @@ -528,7 +547,7 @@ impl Validator { .with_span()); // TODO: keep some EP span information? } - match self.validate_entry_point(ep, module, &mod_info) { + match self.validate_entry_point(ep, module, &mod_info, &global_expr_kind) { Ok(info) => mod_info.entry_points.push(info), Err(error) => { return Err(error.and_then(|source| { diff --git a/naga/src/valid/type.rs b/naga/src/valid/type.rs index b8eb618ed4..03e87fd99b 100644 --- a/naga/src/valid/type.rs +++ b/naga/src/valid/type.rs @@ -63,6 +63,7 @@ bitflags::bitflags! { } #[derive(Clone, Copy, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] pub enum Disalignment { #[error("The array stride {stride} is not a multiple of the required alignment {alignment}")] ArrayStride { stride: u32, alignment: Alignment }, @@ -87,6 +88,7 @@ pub enum Disalignment { } #[derive(Clone, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] pub enum TypeError { #[error("Capability {0:?} is required")] MissingCapability(Capabilities), diff --git a/naga/tests/in/overrides.wgsl b/naga/tests/in/overrides.wgsl index b498a8b527..41e99f9426 100644 --- a/naga/tests/in/overrides.wgsl +++ b/naga/tests/in/overrides.wgsl @@ -6,7 +6,7 @@ override depth: f32; // Specified at the API level using // the name "depth". // Must be overridden. - // override height = 2 * depth; // The default value + override height = 2 * depth; // The default value // (if not set at the API level), // depends on another // overridable constant. diff --git a/naga/tests/out/analysis/overrides.info.ron b/naga/tests/out/analysis/overrides.info.ron index 481c3eac99..7a2447f3c0 100644 --- a/naga/tests/out/analysis/overrides.info.ron +++ b/naga/tests/out/analysis/overrides.info.ron @@ -33,6 +33,12 @@ kind: Float, width: 4, ))), + Handle(2), + Value(Scalar(( + kind: Float, + width: 4, + ))), + Handle(2), Value(Scalar(( kind: Float, width: 4, diff --git a/naga/tests/out/hlsl/overrides.hlsl b/naga/tests/out/hlsl/overrides.hlsl index 63b13a5d2b..0a849fd4db 100644 --- a/naga/tests/out/hlsl/overrides.hlsl +++ b/naga/tests/out/hlsl/overrides.hlsl @@ -3,6 +3,7 @@ static const float specular_param = 2.3; static const float gain = 1.1; static const float width = 0.0; static const float depth = 2.3; +static const float height = 4.6; static const float inferred_f32_ = 2.718; [numthreads(1, 1, 1)] diff --git a/naga/tests/out/ir/overrides.compact.ron b/naga/tests/out/ir/overrides.compact.ron index af4b31eba9..d15abbd033 100644 --- a/naga/tests/out/ir/overrides.compact.ron +++ b/naga/tests/out/ir/overrides.compact.ron @@ -52,11 +52,17 @@ ty: 2, init: None, ), + ( + name: Some("height"), + id: None, + ty: 2, + init: Some(6), + ), ( name: Some("inferred_f32"), id: None, ty: 2, - init: Some(4), + init: Some(7), ), ], global_variables: [], @@ -64,6 +70,13 @@ Literal(Bool(true)), Literal(F32(2.3)), Literal(F32(0.0)), + Override(5), + Literal(F32(2.0)), + Binary( + op: Multiply, + left: 5, + right: 4, + ), Literal(F32(2.718)), ], functions: [], diff --git a/naga/tests/out/ir/overrides.ron b/naga/tests/out/ir/overrides.ron index af4b31eba9..d15abbd033 100644 --- a/naga/tests/out/ir/overrides.ron +++ b/naga/tests/out/ir/overrides.ron @@ -52,11 +52,17 @@ ty: 2, init: None, ), + ( + name: Some("height"), + id: None, + ty: 2, + init: Some(6), + ), ( name: Some("inferred_f32"), id: None, ty: 2, - init: Some(4), + init: Some(7), ), ], global_variables: [], @@ -64,6 +70,13 @@ Literal(Bool(true)), Literal(F32(2.3)), Literal(F32(0.0)), + Override(5), + Literal(F32(2.0)), + Binary( + op: Multiply, + left: 5, + right: 4, + ), Literal(F32(2.718)), ], functions: [], diff --git a/naga/tests/out/msl/overrides.msl b/naga/tests/out/msl/overrides.msl index 419edd8904..13a3b623a0 100644 --- a/naga/tests/out/msl/overrides.msl +++ b/naga/tests/out/msl/overrides.msl @@ -9,6 +9,7 @@ constant float specular_param = 2.3; constant float gain = 1.1; constant float width = 0.0; constant float depth = 2.3; +constant float height = 4.6; constant float inferred_f32_ = 2.718; kernel void main_( diff --git a/naga/tests/out/spv/overrides.main.spvasm b/naga/tests/out/spv/overrides.main.spvasm index 7dfa6df3e5..7731edfb93 100644 --- a/naga/tests/out/spv/overrides.main.spvasm +++ b/naga/tests/out/spv/overrides.main.spvasm @@ -1,25 +1,27 @@ ; SPIR-V ; Version: 1.0 ; Generator: rspirv -; Bound: 15 +; Bound: 17 OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 -OpEntryPoint GLCompute %12 "main" -OpExecutionMode %12 LocalSize 1 1 1 +OpEntryPoint GLCompute %14 "main" +OpExecutionMode %14 LocalSize 1 1 1 %2 = OpTypeVoid %3 = OpTypeBool %4 = OpTypeFloat 32 %5 = OpConstantTrue %3 %6 = OpConstant %4 2.3 %7 = OpConstant %4 0.0 -%8 = OpConstant %4 2.718 -%9 = OpConstantFalse %3 -%10 = OpConstant %4 1.1 -%13 = OpTypeFunction %2 -%12 = OpFunction %2 None %13 -%11 = OpLabel -OpBranch %14 -%14 = OpLabel +%8 = OpConstantFalse %3 +%9 = OpConstant %4 1.1 +%10 = OpConstant %4 2.0 +%11 = OpConstant %4 4.6 +%12 = OpConstant %4 2.718 +%15 = OpTypeFunction %2 +%14 = OpFunction %2 None %15 +%13 = OpLabel +OpBranch %16 +%16 = OpLabel OpReturn OpFunctionEnd \ No newline at end of file diff --git a/naga/tests/out/wgsl/quad_glsl.vert.wgsl b/naga/tests/out/wgsl/quad_glsl.vert.wgsl index 8942e4c72f..0a3d7cecac 100644 --- a/naga/tests/out/wgsl/quad_glsl.vert.wgsl +++ b/naga/tests/out/wgsl/quad_glsl.vert.wgsl @@ -14,8 +14,8 @@ fn main_1() { let _e4 = a_uv_1; v_uv = _e4; let _e6 = a_pos_1; - let _e8 = (c_scale * _e6); - gl_Position = vec4(_e8.x, _e8.y, 0f, 1f); + let _e7 = (c_scale * _e6); + gl_Position = vec4(_e7.x, _e7.y, 0f, 1f); return; }