diff --git a/src/front/glsl/context.rs b/src/front/glsl/context.rs index 0f836def69..0fe0f734a3 100644 --- a/src/front/glsl/context.rs +++ b/src/front/glsl/context.rs @@ -245,19 +245,6 @@ impl<'a> Context<'a> { } pub fn add_expression(&mut self, expr: Expression, meta: Span) -> Result> { - let mut append = |arena: &mut Arena, expr: Expression, span| { - let is_running = self.emitter.is_running(); - let needs_pre_emit = expr.needs_pre_emit(); - if is_running && needs_pre_emit { - self.body.extend(self.emitter.finish(arena)); - } - let h = arena.append(expr, span); - if is_running && needs_pre_emit { - self.emitter.start(arena); - } - h - }; - let (expressions, const_expressions) = if self.is_const { (&mut self.module.const_expressions, None) } else { @@ -269,7 +256,10 @@ impl<'a> Context<'a> { constants: &self.module.constants, expressions, const_expressions, - append: (!self.is_const).then_some(&mut append), + emitter: (!self.is_const).then_some(crate::proc::ConstantEvaluatorEmitter { + emitter: &mut self.emitter, + block: &mut self.body, + }), }; let res = eval.try_eval_and_append(&expr, meta).map_err(|e| Error { @@ -280,7 +270,17 @@ impl<'a> Context<'a> { match res { Ok(expr) => Ok(expr), Err(e) if self.is_const => Err(e), - Err(_) => Ok(append(&mut self.expressions, expr, meta)), + Err(_) => { + let needs_pre_emit = expr.needs_pre_emit(); + if needs_pre_emit { + self.body.extend(self.emitter.finish(expressions)); + } + let h = expressions.append(expr, meta); + if needs_pre_emit { + self.emitter.start(expressions); + } + Ok(h) + } } } diff --git a/src/front/wgsl/lower/mod.rs b/src/front/wgsl/lower/mod.rs index 905eb55ea5..f36840d3f1 100644 --- a/src/front/wgsl/lower/mod.rs +++ b/src/front/wgsl/lower/mod.rs @@ -6,8 +6,8 @@ use crate::front::wgsl::parse::number::Number; use crate::front::wgsl::parse::{ast, conv}; use crate::front::Typifier; use crate::proc::{ - ensure_block_returns, Alignment, ConstantEvaluator, Emitter, Layouter, ResolveContext, - TypeResolution, + ensure_block_returns, Alignment, ConstantEvaluator, ConstantEvaluatorEmitter, Emitter, + Layouter, ResolveContext, TypeResolution, }; use crate::{Arena, FastHashMap, FastIndexMap, Handle, Span}; @@ -338,20 +338,10 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { constants: &self.module.constants, expressions: rctx.naga_expressions, const_expressions: Some(&self.module.const_expressions), - append: Some( - |arena: &mut Arena, expr: crate::Expression, span| { - let is_running = rctx.emitter.is_running(); - let needs_pre_emit = expr.needs_pre_emit(); - if is_running && needs_pre_emit { - rctx.block.extend(rctx.emitter.finish(arena)); - } - let h = arena.append(expr, span); - if is_running && needs_pre_emit { - rctx.emitter.start(arena); - } - h - }, - ), + emitter: Some(ConstantEvaluatorEmitter { + emitter: rctx.emitter, + block: rctx.block, + }), }; match eval.try_eval_and_append(&expr, span) { @@ -365,15 +355,7 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> { constants: &self.module.constants, expressions: &mut self.module.const_expressions, const_expressions: None, - append: None::< - Box< - dyn FnMut( - &mut Arena, - crate::Expression, - Span, - ) -> Handle, - >, - >, + emitter: None, }; eval.try_eval_and_append(&expr, span) diff --git a/src/proc/constant_evaluator.rs b/src/proc/constant_evaluator.rs index 218252e1f9..2580cadf2d 100644 --- a/src/proc/constant_evaluator.rs +++ b/src/proc/constant_evaluator.rs @@ -5,15 +5,22 @@ use crate::{ }; #[derive(Debug)] -pub struct ConstantEvaluator< - 'a, - F: FnMut(&mut Arena, Expression, Span) -> Handle, -> { +pub struct ConstantEvaluator<'a> { pub types: &'a mut UniqueArena, pub constants: &'a Arena, pub expressions: &'a mut Arena, pub const_expressions: Option<&'a Arena>, - pub append: Option, + + /// When `expressions` refers to a function's local expression + /// arena, this is the emitter we should interrupt when inserting + /// new things into it. + pub emitter: Option>, +} + +#[derive(Debug)] +pub struct ConstantEvaluatorEmitter<'a> { + pub emitter: &'a mut super::Emitter, + pub block: &'a mut crate::Block, } #[derive(Clone, Debug, PartialEq, thiserror::Error)] @@ -99,9 +106,7 @@ impl Arena { } } -impl<'a, F: FnMut(&mut Arena, Expression, Span) -> Handle> - ConstantEvaluator<'a, F> -{ +impl ConstantEvaluator<'_> { fn check_and_get( &mut self, expr: Handle, @@ -800,11 +805,20 @@ impl<'a, F: FnMut(&mut Arena, Expression, Span) -> Handle Handle { - if let Some(ref mut append) = self.append { - append(self.expressions, expr, span) - } else { - self.expressions.append(expr, span) + if let Some(ref mut emitter) = self.emitter { + let is_running = emitter.emitter.is_running(); + let needs_pre_emit = expr.needs_pre_emit(); + if is_running && needs_pre_emit { + emitter + .block + .extend(emitter.emitter.finish(self.expressions)); + let h = self.expressions.append(expr, span); + emitter.emitter.start(self.expressions); + return h; + } } + + self.expressions.append(expr, span) } } @@ -973,15 +987,7 @@ mod tests { constants: &constants, expressions: &mut const_expressions, const_expressions: None, - append: None::< - Box< - dyn FnMut( - &mut Arena, - Expression, - crate::Span, - ) -> crate::Handle, - >, - >, + emitter: None, }; let res1 = solver @@ -1068,15 +1074,7 @@ mod tests { constants: &constants, expressions: &mut const_expressions, const_expressions: None, - append: None::< - Box< - dyn FnMut( - &mut Arena, - Expression, - crate::Span, - ) -> crate::Handle, - >, - >, + emitter: None, }; let res = solver @@ -1195,15 +1193,7 @@ mod tests { constants: &constants, expressions: &mut const_expressions, const_expressions: None, - append: None::< - Box< - dyn FnMut( - &mut Arena, - Expression, - crate::Span, - ) -> crate::Handle, - >, - >, + emitter: None, }; let root1 = Expression::AccessIndex { base, index: 1 }; diff --git a/src/proc/mod.rs b/src/proc/mod.rs index bbdd75eb32..2282520a3a 100644 --- a/src/proc/mod.rs +++ b/src/proc/mod.rs @@ -10,7 +10,7 @@ mod namer; mod terminator; mod typifier; -pub use constant_evaluator::{ConstantEvaluator, ConstantEvaluatorError}; +pub use constant_evaluator::{ConstantEvaluator, ConstantEvaluatorEmitter, ConstantEvaluatorError}; pub use emitter::Emitter; pub use index::{BoundsCheckPolicies, BoundsCheckPolicy, IndexableLength, IndexableLengthError}; pub use layouter::{Alignment, LayoutError, LayoutErrorInner, Layouter, TypeLayout};