From d0c84a5ffac3b2d4d3fe1101864006f0bbc17d50 Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Sat, 6 Feb 2021 02:20:54 -0500 Subject: [PATCH] [spv-in] support for void function calls --- src/arena.rs | 10 +++ src/front/spv/function.rs | 39 ++++----- src/front/spv/mod.rs | 161 +++++++++++++++++++++++--------------- src/lib.rs | 2 +- 4 files changed, 125 insertions(+), 87 deletions(-) diff --git a/src/arena.rs b/src/arena.rs index 0f673fa684..347dfe206c 100644 --- a/src/arena.rs +++ b/src/arena.rs @@ -126,6 +126,16 @@ impl Arena { }) } + /// Returns a iterator over the items stored in this arena, + /// returning both the item's handle and a mutable reference to it. + pub fn iter_mut(&mut self) -> impl DoubleEndedIterator, &mut T)> { + self.data.iter_mut().enumerate().map(|(i, v)| { + let position = i + 1; + let index = unsafe { Index::new_unchecked(position as u32) }; + (Handle::new(index), v) + }) + } + /// Adds a new value to the arena, returning a typed handle. /// /// The value is not linked to any SPIR-V module. diff --git a/src/front/spv/function.rs b/src/front/spv/function.rs index 654160b303..78e87f677d 100644 --- a/src/front/spv/function.rs +++ b/src/front/spv/function.rs @@ -57,23 +57,23 @@ impl> super::Parser { ) -> Result<(), Error> { self.switch(ModuleState::Function, inst.op)?; inst.expect(5)?; - let result_type = self.next()?; + let result_type_id = self.next()?; let fun_id = self.next()?; let _fun_control = self.next()?; - let fun_type = self.next()?; + let fun_type_id = self.next()?; let mut fun = { - let ft = self.lookup_function_type.lookup(fun_type)?; - if ft.return_type_id != result_type { - return Err(Error::WrongFunctionResultType(result_type)); + let ft = self.lookup_function_type.lookup(fun_type_id)?; + if ft.return_type_id != result_type_id { + return Err(Error::WrongFunctionResultType(result_type_id)); } crate::Function { name: self.future_decor.remove(&fun_id).and_then(|dec| dec.name), arguments: Vec::with_capacity(ft.parameter_type_ids.len()), - return_type: if self.lookup_void_type.contains(&result_type) { + return_type: if self.lookup_void_type == Some(result_type_id) { None } else { - Some(self.lookup_type.lookup(result_type)?.handle) + Some(self.lookup_type.lookup(result_type_id)?.handle) }, global_usage: Vec::new(), local_variables: Arena::new(), @@ -101,7 +101,7 @@ impl> super::Parser { if type_id != self .lookup_function_type - .lookup(fun_type)? + .lookup(fun_type_id)? .parameter_type_ids[i] { return Err(Error::WrongFunctionArgumentType(type_id)); @@ -116,7 +116,6 @@ impl> super::Parser { // Read body let mut flow_graph = FlowGraph::new(); - let base_deferred_call_index = self.deferred_function_calls.len(); // Scan the blocks and add them as nodes loop { @@ -156,39 +155,29 @@ impl> super::Parser { // done fun.fill_global_use(&module.global_variables); - let source = match self.lookup_entry_point.remove(&fun_id) { + let dump_suffix = match self.lookup_entry_point.remove(&fun_id) { Some(ep) => { + let dump_name = format!("flow.{:?}-{}.dot", ep.stage, ep.name); module.entry_points.insert( - (ep.stage, ep.name.clone()), + (ep.stage, ep.name), crate::EntryPoint { early_depth_test: ep.early_depth_test, workgroup_size: ep.workgroup_size, function: fun, }, ); - DeferredSource::EntryPoint(ep.stage, ep.name) + dump_name } None => { let handle = module.functions.append(fun); self.lookup_function.insert(fun_id, handle); - DeferredSource::Function(handle) + format!("flow.Fun-{}.dot", handle.index()) } }; - for dfc in self.deferred_function_calls[base_deferred_call_index..].iter_mut() { - dfc.source = source.clone(); - } - if let Some(ref prefix) = self.options.flow_graph_dump_prefix { let dump = flow_graph.to_graphviz().unwrap_or_default(); - let suffix = match source { - DeferredSource::Undefined => unreachable!(), - DeferredSource::EntryPoint(stage, ref name) => { - format!("flow.{:?}-{}.dot", stage, name) - } - DeferredSource::Function(handle) => format!("flow.Fun-{}.dot", handle.index()), - }; - let _ = std::fs::write(prefix.join(suffix), dump); + let _ = std::fs::write(prefix.join(dump_suffix), dump); } self.lookup_expression.clear(); diff --git a/src/front/spv/mod.rs b/src/front/spv/mod.rs index 783245cb39..0652af078d 100644 --- a/src/front/spv/mod.rs +++ b/src/front/spv/mod.rs @@ -268,18 +268,6 @@ struct LookupSampledImage { image: Handle, sampler: Handle, } -#[derive(Clone, Debug)] -enum DeferredSource { - Undefined, - EntryPoint(crate::ShaderStage, String), - Function(Handle), -} -struct DeferredFunctionCall { - source: DeferredSource, - expr_handle: Handle, - dst_id: spirv::Word, - arguments: Vec>, -} #[derive(Clone, Debug)] pub struct Assignment { @@ -302,7 +290,7 @@ pub struct Parser { lookup_member_type_id: FastHashMap<(Handle, MemberIndex), spirv::Word>, handle_sampling: FastHashMap, SamplingFlags>, lookup_type: FastHashMap, - lookup_void_type: FastHashSet, + lookup_void_type: Option, lookup_storage_buffer_types: FastHashSet>, // Lookup for samplers and sampled images, storing flags on how they are used. lookup_constant: FastHashMap, @@ -312,7 +300,9 @@ pub struct Parser { lookup_function_type: FastHashMap, lookup_function: FastHashMap>, lookup_entry_point: FastHashMap, - deferred_function_calls: Vec, + //Note: the key here is fully artificial, has nothing to do with the module + deferred_function_calls: FastHashMap, spirv::Word>, + dummy_functions: Arena, options: Options, } @@ -328,7 +318,7 @@ impl> Parser { handle_sampling: FastHashMap::default(), lookup_member_type_id: FastHashMap::default(), lookup_type: FastHashMap::default(), - lookup_void_type: FastHashSet::default(), + lookup_void_type: None, lookup_storage_buffer_types: FastHashSet::default(), lookup_constant: FastHashMap::default(), lookup_variable: FastHashMap::default(), @@ -337,7 +327,8 @@ impl> Parser { lookup_function_type: FastHashMap::default(), lookup_function: FastHashMap::default(), lookup_entry_point: FastHashMap::default(), - deferred_function_calls: Vec::new(), + deferred_function_calls: FastHashMap::default(), + dummy_functions: Arena::new(), options: options.clone(), } } @@ -510,7 +501,7 @@ impl> Parser { const_arena: &Arena, global_arena: &Arena, ) -> Result { - let mut assignments = Vec::new(); + let mut block = Vec::new(); let mut phis = Vec::new(); let mut merge = None; let terminator = loop { @@ -782,8 +773,8 @@ impl> Parser { } let base_expr = self.lookup_expression.lookup(pointer_id)?; let value_expr = self.lookup_expression.lookup(value_id)?; - assignments.push(Assignment { - to: base_expr.handle, + block.push(crate::Statement::Store { + pointer: base_expr.handle, value: value_expr.handle, }); } @@ -1315,22 +1306,29 @@ impl> Parser { let arg_id = self.next()?; arguments.push(self.lookup_expression.lookup(arg_id)?.handle); } - // will be replaced by the actual expression - let expr = crate::Expression::FunctionArgument(!0); - let expr_handle = expressions.append(expr); - self.deferred_function_calls.push(DeferredFunctionCall { - source: DeferredSource::Undefined, - expr_handle, - dst_id: func_id, - arguments, - }); - self.lookup_expression.insert( - result_id, - LookupExpression { - handle: expr_handle, - type_id: result_type_id, - }, - ); + + // We just need an unique handle here, nothing more. + let function = self.dummy_functions.append(crate::Function::default()); + self.deferred_function_calls.insert(function, func_id); + + if self.lookup_void_type == Some(result_type_id) { + block.push(crate::Statement::Call { + function, + arguments, + }); + } else { + let expr_handle = expressions.append(crate::Expression::Call { + function, + arguments, + }); + self.lookup_expression.insert( + result_id, + LookupExpression { + handle: expr_handle, + type_id: result_type_id, + }, + ); + } } Op::ExtInst => { use crate::MathFunction as Mf; @@ -1551,14 +1549,6 @@ impl> Parser { } }; - let mut block = Vec::new(); - for assignment in assignments.iter() { - block.push(crate::Statement::Store { - pointer: assignment.to, - value: assignment.value, - }); - } - Ok(ControlFlowNode { id: block_id, ty: None, @@ -1610,6 +1600,67 @@ impl> Parser { } } + fn patch_function_call_statements( + &self, + statements: &mut [crate::Statement], + ) -> Result<(), Error> { + use crate::Statement as S; + for statement in statements.iter_mut() { + match *statement { + S::Block(ref mut block) => { + self.patch_function_call_statements(block)?; + } + S::If { + condition: _, + ref mut accept, + ref mut reject, + } => { + self.patch_function_call_statements(accept)?; + self.patch_function_call_statements(reject)?; + } + S::Switch { + selector: _, + ref mut cases, + ref mut default, + } => { + for case in cases.iter_mut() { + self.patch_function_call_statements(&mut case.body)?; + } + self.patch_function_call_statements(default)?; + } + S::Loop { + ref mut body, + ref mut continuing, + } => { + self.patch_function_call_statements(body)?; + self.patch_function_call_statements(continuing)?; + } + S::Break | S::Continue | S::Return { .. } | S::Kill | S::Store { .. } => {} + S::Call { + ref mut function, .. + } => { + let fun_id = self.deferred_function_calls[function]; + *function = *self.lookup_function.lookup(fun_id)?; + } + } + } + Ok(()) + } + + fn patch_function_calls(&self, fun: &mut crate::Function) -> Result<(), Error> { + for (_, expr) in fun.expressions.iter_mut() { + if let crate::Expression::Call { + ref mut function, .. + } = *expr + { + let fun_id = self.deferred_function_calls[function]; + *function = *self.lookup_function.lookup(fun_id)?; + } + } + self.patch_function_call_statements(&mut fun.body)?; + Ok(()) + } + pub fn parse(mut self) -> Result { let mut module = { if self.next()? != spirv::MAGIC_NUMBER { @@ -1680,23 +1731,11 @@ impl> Parser { } } - for dfc in self.deferred_function_calls.drain(..) { - let dst_handle = *self.lookup_function.lookup(dfc.dst_id)?; - let fun = match dfc.source { - DeferredSource::Undefined => unreachable!(), - DeferredSource::Function(fun_handle) => module.functions.get_mut(fun_handle), - DeferredSource::EntryPoint(stage, name) => { - &mut module - .entry_points - .get_mut(&(stage, name)) - .unwrap() - .function - } - }; - *fun.expressions.get_mut(dfc.expr_handle) = crate::Expression::Call { - function: dst_handle, - arguments: dfc.arguments, - }; + for (_, func) in module.functions.iter_mut() { + self.patch_function_calls(func)?; + } + for (_, ep) in module.entry_points.iter_mut() { + self.patch_function_calls(&mut ep.function)?; } if !self.future_decor.is_empty() { @@ -1912,7 +1951,7 @@ impl> Parser { self.switch(ModuleState::Type, inst.op)?; inst.expect(2)?; let id = self.next()?; - self.lookup_void_type.insert(id); + self.lookup_void_type = Some(id); Ok(()) } diff --git a/src/lib.rs b/src/lib.rs index 92058c9032..abad73cb4e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -821,7 +821,7 @@ pub struct FunctionArgument { } /// A function defined in the module. -#[derive(Debug)] +#[derive(Debug, Default)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] pub struct Function {