From 3ea6ca428f5ba91ae0b5ecb60a96976ff4c10988 Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Fri, 20 Mar 2020 12:54:51 -0400 Subject: [PATCH] Move the global usage into the IR function --- src/back/msl.rs | 34 ++++++++++++++++------------------ src/front/spirv.rs | 17 ++++------------- src/front/wgsl.rs | 39 ++++++--------------------------------- src/lib.rs | 10 ++++++++-- src/proc/interface.rs | 28 +++++++++++----------------- src/proc/mod.rs | 1 - 6 files changed, 45 insertions(+), 84 deletions(-) diff --git a/src/back/msl.rs b/src/back/msl.rs index e4896380da..7b1d806a97 100644 --- a/src/back/msl.rs +++ b/src/back/msl.rs @@ -22,8 +22,7 @@ use std::{ use crate::{ arena::Handle, - proc::GlobalUse, - FastHashMap, FastHashSet + FastHashMap, }; /// Expect all the global variables to have a pointer type, @@ -226,7 +225,7 @@ impl AsName for Option { struct TypedGlobalVariable<'a> { module: &'a crate::Module, handle: crate::Handle, - usage: GlobalUse, + usage: crate::GlobalUse, } impl Display for TypedGlobalVariable<'_> { fn fmt(&self, formatter: &mut Formatter<'_>) -> Result<(), FmtError> { @@ -236,7 +235,7 @@ impl Display for TypedGlobalVariable<'_> { spirv::StorageClass::Uniform | spirv::StorageClass::UniformConstant | spirv::StorageClass::StorageBuffer => { - let space = if self.usage.contains(GlobalUse::STORE) { + let space = if self.usage.contains(crate::GlobalUse::STORE) { "device " } else { "constant " @@ -773,9 +772,8 @@ impl Writer { let fun_name = fun.name.or_index(fun_handle); // find the entry point(s) and inputs/outputs let mut exec_model = None; - let global_use = GlobalUse::scan(fun, &module.global_variables); let mut last_used_global = None; - for ((handle, var), &usage) in module.global_variables.iter().zip(global_use.iter()) { + for ((handle, var), &usage) in module.global_variables.iter().zip(&fun.global_usage) { match var.class { spirv::StorageClass::Input => { if let Some(crate::Binding::Location(_)) = var.binding { @@ -789,12 +787,8 @@ impl Writer { last_used_global = Some(handle); } } - let mut var_inputs = FastHashSet::default(); - let mut var_outputs = FastHashSet::default(); for ep in module.entry_points.iter() { - if ep.function == fun_handle{ - var_inputs.extend(ep.inputs.iter().cloned()); - var_outputs.extend(ep.outputs.iter().cloned()); + if ep.function == fun_handle { if exec_model.is_some() { if exec_model != Some(ep.exec_model) { return Err(Error::MixedExecutionModels(fun_handle)); @@ -819,8 +813,10 @@ impl Writer { if em != spirv::ExecutionModel::GLCompute { writeln!(self.out, "struct {} {{", location_input_name)?; - for &handle in var_inputs.iter() { - let var = &module.global_variables[handle]; + for ((handle, var), &usage) in module.global_variables.iter().zip(&fun.global_usage) { + if var.class != spirv::StorageClass::Input || !usage.contains(crate::GlobalUse::LOAD) { + continue + } // if it's a struct, lift all the built-in contents up to the root let mut ty_handle = var.ty; if GLOBAL_POINTERS { @@ -839,7 +835,7 @@ impl Writer { } } else { if let Some(ref binding@crate::Binding::Location(_)) = var.binding { - let tyvar = TypedGlobalVariable { module, handle, usage: GlobalUse::empty() }; + let tyvar = TypedGlobalVariable { module, handle, usage: crate::GlobalUse::empty() }; let resolved = options.resolve_binding(binding, in_mode)?; writeln!(self.out, "\t{} [[{}]];", tyvar, resolved)?; } @@ -847,8 +843,10 @@ impl Writer { } writeln!(self.out, "}};")?; writeln!(self.out, "struct {} {{", output_name)?; - for &handle in var_outputs.iter() { - let var = &module.global_variables[handle]; + for ((handle, var), &usage) in module.global_variables.iter().zip(&fun.global_usage) { + if var.class != spirv::StorageClass::Output || !usage.contains(crate::GlobalUse::STORE) { + continue + } // if it's a struct, lift all the built-in contents up to the root let mut ty_handle = var.ty; if GLOBAL_POINTERS { @@ -867,7 +865,7 @@ impl Writer { writeln!(self.out, "\t{} {} [[{}]];", ty_name, name, resolved)?; } } else { - let tyvar = TypedGlobalVariable { module, handle, usage: GlobalUse::empty() }; + let tyvar = TypedGlobalVariable { module, handle, usage: crate::GlobalUse::empty() }; write!(self.out, "\t{}", tyvar)?; if let Some(ref binding) = var.binding { let resolved = options.resolve_binding(binding, out_mode)?; @@ -885,7 +883,7 @@ impl Writer { writeln!(self.out, "{} void {}(", em_str, fun_name)?; } - for ((handle, var), &usage) in module.global_variables.iter().zip(global_use.iter()) { + for ((handle, var), &usage) in module.global_variables.iter().zip(&fun.global_usage) { if usage.is_empty() || var.class == spirv::StorageClass::Output { continue } diff --git a/src/front/spirv.rs b/src/front/spirv.rs index 7ee620bf81..6e3d54e9ff 100644 --- a/src/front/spirv.rs +++ b/src/front/spirv.rs @@ -732,22 +732,11 @@ impl> Parser { module.entry_points.reserve(entry_points.len()); for raw in entry_points { - let mut ep = crate::EntryPoint { + module.entry_points.push(crate::EntryPoint { exec_model: raw.exec_model, name: raw.name, function: *self.lookup_function.lookup(raw.function_id)?, - inputs: FastHashSet::default(), - outputs: FastHashSet::default(), - }; - for var_id in raw.variable_ids { - let handle = self.lookup_variable.lookup(var_id)?.handle; - match module.global_variables[handle].class { - spirv::StorageClass::Input => ep.inputs.insert(handle), - spirv::StorageClass::Output => ep.outputs.insert(handle), - other => return Err(Error::InvalidVariableClass(other)), - }; - } - module.entry_points.push(ep); + }); } Ok(module) @@ -1405,6 +1394,7 @@ impl> Parser { } else { Some(self.lookup_type.lookup(result_type)?.handle) }, + global_usage: Vec::new(), local_variables: Arena::new(), expressions: self.make_expression_storage(), body: Vec::new(), @@ -1447,6 +1437,7 @@ impl> Parser { } } // done + fun.global_usage = crate::GlobalUse::scan(&fun.expressions, &fun.body, &module.global_variables); let handle = module.functions.append(fun); self.lookup_function.insert(fun_id, handle); self.lookup_expression.clear(); diff --git a/src/front/wgsl.rs b/src/front/wgsl.rs index b05bf82752..7a963b68bc 100644 --- a/src/front/wgsl.rs +++ b/src/front/wgsl.rs @@ -1,7 +1,7 @@ use crate::{ arena::{Arena, Handle}, - proc::{GlobalUse, Typifier, ResolveError}, - FastHashMap, FastHashSet, + proc::{Typifier, ResolveError}, + FastHashMap, }; @@ -1205,12 +1205,15 @@ impl Parser { global_vars: &module.global_variables, })?; // done + let global_usage = crate::GlobalUse::scan(&expressions, &body, &module.global_variables); self.scopes.pop(); + let fun = crate::Function { name: Some(fun_name.to_owned()), control: spirv::FunctionControl::empty(), parameter_types, return_type, + global_usage, local_variables, expressions, body, @@ -1341,43 +1344,13 @@ impl Parser { lexer.expect(Token::Operation('='))?; let fun_ident = lexer.next_ident()?; lexer.expect(Token::Separator(';'))?; - let (fun_handle, function) = module.functions + let (fun_handle, _) = module.functions .iter() .find(|(_, fun)| fun.name.as_ref().map(|s| s.as_str()) == Some(fun_ident)) .ok_or(Error::UnknownFunction(fun_ident))?; - - let uses = GlobalUse::scan(function, &module.global_variables); - let mut inputs = FastHashSet::default(); - let mut outputs = FastHashSet::default(); - for ((handle, var), &usage) in module.global_variables.iter().zip(uses.iter()) { - match var.class { - _ if usage.is_empty() => {} - spirv::StorageClass::Input if usage.contains(GlobalUse::LOAD) => { - inputs.insert(handle); - } - spirv::StorageClass::Output if usage.contains(GlobalUse::STORE) => { - outputs.insert(handle); - } - spirv::StorageClass::Input | - spirv::StorageClass::Output => { - let name = lookup_global_expression - .iter() - .find(|&(_, h)| match *h { - crate::Expression::GlobalVariable(h) => h == handle, - _ => false, - }) - .map(|(name, _)| name) - .unwrap(); - return Err(Error::MutabilityViolation(name)); - } - _ => {} - } - } module.entry_points.push(crate::EntryPoint { exec_model, name: export_name.unwrap_or(fun_ident).to_owned(), - inputs, - outputs, function: fun_handle, }); } diff --git a/src/lib.rs b/src/lib.rs index f7b091958d..9260a91e77 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -117,6 +117,13 @@ pub enum Binding { Descriptor { set: spirv::Word, binding: spirv::Word }, } +bitflags::bitflags! { + pub struct GlobalUse: u8 { + const LOAD = 0x1; + const STORE = 0x2; + } +} + #[derive(Clone, Debug)] pub struct GlobalVariable { pub name: Option, @@ -271,6 +278,7 @@ pub struct Function { pub control: spirv::FunctionControl, pub parameter_types: Vec>, pub return_type: Option>, + pub global_usage: Vec, pub local_variables: Arena, pub expressions: Arena, pub body: Block, @@ -280,8 +288,6 @@ pub struct Function { pub struct EntryPoint { pub exec_model: spirv::ExecutionModel, pub name: String, - pub inputs: FastHashSet>, - pub outputs: FastHashSet>, pub function: Handle, } diff --git a/src/proc/interface.rs b/src/proc/interface.rs index 2bea51f417..9522ac981b 100644 --- a/src/proc/interface.rs +++ b/src/proc/interface.rs @@ -2,16 +2,9 @@ use crate::{ arena::{Arena, Handle}, }; -bitflags::bitflags! { - pub struct GlobalUse: u8 { - const LOAD = 0x1; - const STORE = 0x2; - } -} - struct Interface<'a> { expressions: &'a Arena, - uses: Vec, + uses: Vec, } impl<'a> Interface<'a> { @@ -33,7 +26,7 @@ impl<'a> Interface<'a> { } E::FunctionParameter(_) => {}, E::GlobalVariable(handle) => { - self.uses[handle.index()] |= GlobalUse::LOAD; + self.uses[handle.index()] |= crate::GlobalUse::LOAD; } E::LocalVariable(_) => {} E::Load { pointer } => { @@ -117,7 +110,7 @@ impl<'a> Interface<'a> { left = base; } crate::Expression::GlobalVariable(handle) => { - self.uses[handle.index()] |= GlobalUse::STORE; + self.uses[handle.index()] |= crate::GlobalUse::STORE; break; } _ => break, @@ -130,16 +123,17 @@ impl<'a> Interface<'a> { } } -impl GlobalUse { +impl crate::GlobalUse { pub fn scan( - fun: &crate::Function, + expressions: &Arena, + body: &[crate::Statement], globals: &Arena, - ) -> Box<[Self]> { + ) -> Vec { let mut io = Interface { - expressions: &fun.expressions, - uses: vec![GlobalUse::empty(); globals.len()], + expressions, + uses: vec![crate::GlobalUse::empty(); globals.len()], }; - io.collect(&fun.body); - io.uses.into_boxed_slice() + io.collect(body); + io.uses } } diff --git a/src/proc/mod.rs b/src/proc/mod.rs index 9aba4f1cf0..cb78f765fd 100644 --- a/src/proc/mod.rs +++ b/src/proc/mod.rs @@ -1,5 +1,4 @@ mod interface; mod typifier; -pub use interface::GlobalUse; pub use typifier::{ResolveError, Typifier};