From 027634451de050adf786bb48340742e2c41bd362 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Capucho?= Date: Tue, 1 Jun 2021 19:06:14 +0100 Subject: [PATCH] [glsl-in] Collect entry point arguments usage --- src/front/glsl/ast.rs | 42 ++++++++-- src/front/glsl/functions.rs | 148 ++++++++++++++++++++++-------------- src/front/glsl/parser.rs | 30 +++++--- src/front/glsl/variables.rs | 36 ++++++--- 4 files changed, 176 insertions(+), 80 deletions(-) diff --git a/src/front/glsl/ast.rs b/src/front/glsl/ast.rs index a32edde144..491af32c66 100644 --- a/src/front/glsl/ast.rs +++ b/src/front/glsl/ast.rs @@ -12,11 +12,17 @@ use crate::{ }; #[derive(Debug, Clone, Copy)] -pub enum GlobalLookup { +pub enum GlobalLookupKind { Variable(Handle), BlockSelect(Handle, u32), } +#[derive(Debug, Clone, Copy)] +pub struct GlobalLookup { + pub kind: GlobalLookupKind, + pub entry_arg: Option, +} + #[derive(Debug, PartialEq, Eq, Hash)] pub struct FunctionSignature { pub name: String, @@ -33,6 +39,13 @@ pub struct FunctionDeclaration { pub void: bool, } +bitflags::bitflags! { + pub struct EntryArgUse: u32 { + const READ = 0x1; + const WRITE = 0x2; + } +} + #[derive(Debug)] pub struct Program<'a> { pub version: u16, @@ -45,8 +58,10 @@ pub struct Program<'a> { pub global_variables: Vec<(String, GlobalLookup)>, pub constants: Vec<(String, Handle)>, - pub entry_args: Vec<(Binding, bool, Handle)>, + pub entry_args: Vec<(Binding, Handle)>, pub entries: Vec<(String, ShaderStage, Handle)>, + // TODO: More efficient representation + pub function_arg_use: Vec>, pub module: Module, } @@ -65,6 +80,7 @@ impl<'a> Program<'a> { entry_args: Vec::new(), entries: Vec::new(), + function_arg_use: Vec::new(), module: Module::default(), } @@ -150,6 +166,7 @@ pub struct Context<'function> { expressions: &'function mut Arena, pub locals: &'function mut Arena, pub arguments: &'function mut Vec, + pub arg_use: Vec, //TODO: Find less allocation heavy representation pub scopes: Vec>, @@ -173,6 +190,7 @@ impl<'function> Context<'function> { expressions, locals, arguments, + arg_use: vec![EntryArgUse::empty(); program.entry_args.len()], scopes: vec![FastHashMap::default()], lookup_global_var_exps: FastHashMap::with_capacity_and_hasher( @@ -194,6 +212,7 @@ impl<'function> Context<'function> { expr, load: None, mutable: false, + entry_arg: None, }; this.lookup_global_var_exps.insert(name.into(), var); @@ -201,12 +220,13 @@ impl<'function> Context<'function> { for &(ref name, lookup) in program.global_variables.iter() { this.emit_flush(body); - let (expr, load) = match lookup { - GlobalLookup::Variable(v) => ( + let GlobalLookup { kind, entry_arg } = lookup; + let (expr, load) = match kind { + GlobalLookupKind::Variable(v) => ( Expression::GlobalVariable(v), program.module.global_variables[v].class != StorageClass::Handle, ), - GlobalLookup::BlockSelect(handle, index) => { + GlobalLookupKind::BlockSelect(handle, index) => { let base = this.expressions.append(Expression::GlobalVariable(handle)); (Expression::AccessIndex { base, index }, true) @@ -225,6 +245,7 @@ impl<'function> Context<'function> { }, // TODO: respect constant qualifier mutable: true, + entry_arg, }; this.lookup_global_var_exps.insert(name.into(), var); @@ -285,6 +306,7 @@ impl<'function> Context<'function> { expr, load: Some(load), mutable, + entry_arg: None, }, ); } @@ -336,6 +358,7 @@ impl<'function> Context<'function> { expr, load, mutable, + entry_arg: None, }, ); } @@ -455,8 +478,16 @@ impl<'function> Context<'function> { )); } + if let Some(idx) = var.entry_arg { + self.arg_use[idx] |= EntryArgUse::WRITE + } + var.expr } else { + if let Some(idx) = var.entry_arg { + self.arg_use[idx] |= EntryArgUse::READ + } + var.load.unwrap_or(var.expr) } } @@ -617,6 +648,7 @@ pub struct VariableReference { pub expr: Handle, pub load: Option>, pub mutable: bool, + pub entry_arg: Option, } #[derive(Debug, Clone)] diff --git a/src/front/glsl/functions.rs b/src/front/glsl/functions.rs index db6822de4b..a70df7e50b 100644 --- a/src/front/glsl/functions.rs +++ b/src/front/glsl/functions.rs @@ -1,8 +1,7 @@ use crate::{ - proc::ensure_block_returns, Arena, BinaryOperator, Binding, Block, BuiltIn, EntryPoint, - Expression, Function, FunctionArgument, FunctionResult, Handle, MathFunction, - RelationalFunction, SampleLevel, ScalarKind, ShaderStage, Statement, StructMember, - SwizzleComponent, Type, TypeInner, + proc::ensure_block_returns, Arena, BinaryOperator, Block, EntryPoint, Expression, Function, + FunctionArgument, FunctionResult, Handle, MathFunction, RelationalFunction, SampleLevel, + ScalarKind, Statement, StructMember, SwizzleComponent, Type, TypeInner, }; use super::{ast::*, error::ErrorKind, SourceMetadata}; @@ -392,13 +391,15 @@ impl Program<'_> { sig: FunctionSignature, qualifiers: Vec, meta: SourceMetadata, - ) -> Result<(), ErrorKind> { + ) -> Result, ErrorKind> { ensure_block_returns(&mut function.body); let stage = self.entry_points.get(&sig.name); - if let Some(&stage) = stage { + Ok(if let Some(&stage) = stage { let handle = self.module.functions.append(function); self.entries.push((sig.name, stage, handle)); + self.function_arg_use.push(Vec::new()); + handle } else { let void = function.result.is_none(); @@ -412,7 +413,9 @@ impl Program<'_> { decl.defined = true; *self.module.functions.get_mut(decl.handle) = function; + decl.handle } else { + self.function_arg_use.push(Vec::new()); let handle = self.module.functions.append(function); self.lookup_function.insert( sig, @@ -423,10 +426,9 @@ impl Program<'_> { void, }, ); + handle } - } - - Ok(()) + }) } pub fn add_prototype( @@ -438,6 +440,7 @@ impl Program<'_> { ) -> Result<(), ErrorKind> { let void = function.result.is_none(); + self.function_arg_use.push(Vec::new()); let handle = self.module.functions.append(function); if self @@ -462,17 +465,86 @@ impl Program<'_> { Ok(()) } + fn check_call_global( + &self, + caller: Handle, + function_arg_use: &mut [Vec], + stmt: &Statement, + ) { + match *stmt { + Statement::Block(ref block) => { + for stmt in block { + self.check_call_global(caller, function_arg_use, stmt) + } + } + Statement::If { + ref accept, + ref reject, + .. + } => { + for stmt in accept.iter().chain(reject.iter()) { + self.check_call_global(caller, function_arg_use, stmt) + } + } + Statement::Switch { + ref cases, + ref default, + .. + } => { + for stmt in cases + .iter() + .flat_map(|c| c.body.iter()) + .chain(default.iter()) + { + self.check_call_global(caller, function_arg_use, stmt) + } + } + Statement::Loop { + ref body, + ref continuing, + } => { + for stmt in body.iter().chain(continuing.iter()) { + self.check_call_global(caller, function_arg_use, stmt) + } + } + Statement::Call { function, .. } => { + let callee_len = function_arg_use[function.index()].len(); + let caller_len = function_arg_use[caller.index()].len(); + function_arg_use[caller.index()].extend( + std::iter::repeat(EntryArgUse::empty()) + .take(callee_len.saturating_sub(caller_len)), + ); + + for i in 0..callee_len.max(caller_len) { + let callee_use = function_arg_use[function.index()][i]; + function_arg_use[caller.index()][i] |= callee_use + } + } + _ => {} + } + } + pub fn add_entry_points(&mut self) { + let mut function_arg_use = Vec::new(); + std::mem::swap(&mut self.function_arg_use, &mut function_arg_use); + + for (handle, function) in self.module.functions.iter() { + for stmt in function.body.iter() { + self.check_call_global(handle, &mut function_arg_use, stmt) + } + } + for (name, stage, function) in self.entries.iter().cloned() { let mut arguments = Vec::new(); let mut expressions = Arena::new(); let mut body = Vec::new(); - for (binding, input, handle) in self.entry_args.iter().cloned() { - match binding { - Binding::Location { .. } if !input => continue, - Binding::BuiltIn(builtin) if !should_read(builtin, stage) => continue, - _ => {} + for (i, (binding, handle)) in self.entry_args.iter().cloned().enumerate() { + if function_arg_use[function.index()] + .get(i) + .map_or(true, |u| !u.contains(EntryArgUse::READ)) + { + continue; } let ty = self.module.global_variables[handle].ty; @@ -500,11 +572,12 @@ impl Program<'_> { let mut members = Vec::new(); let mut components = Vec::new(); - for (binding, input, handle) in self.entry_args.iter().cloned() { - match binding { - Binding::Location { .. } if input => continue, - Binding::BuiltIn(builtin) if !should_write(builtin, stage) => continue, - _ => {} + for (i, (binding, handle)) in self.entry_args.iter().cloned().enumerate() { + if function_arg_use[function.index()] + .get(i) + .map_or(true, |u| !u.contains(EntryArgUse::WRITE)) + { + continue; } let ty = self.module.global_variables[handle].ty; @@ -556,40 +629,3 @@ impl Program<'_> { } } } - -// FIXME: Both of the functions below should be removed they are a temporary solution -// -// The fix should analyze the entry point and children function calls -// (recursively) and store something like `GlobalUse` and then later only read -// or store the globals that need to be read or written in that stage - -fn should_read(built_in: BuiltIn, stage: ShaderStage) -> bool { - match (built_in, stage) { - (BuiltIn::Position, ShaderStage::Fragment) - | (BuiltIn::BaseInstance, ShaderStage::Vertex) - | (BuiltIn::BaseVertex, ShaderStage::Vertex) - | (BuiltIn::ClipDistance, ShaderStage::Fragment) - | (BuiltIn::InstanceIndex, ShaderStage::Vertex) - | (BuiltIn::VertexIndex, ShaderStage::Vertex) - | (BuiltIn::FrontFacing, ShaderStage::Fragment) - | (BuiltIn::SampleIndex, ShaderStage::Fragment) - | (BuiltIn::SampleMask, ShaderStage::Fragment) - | (BuiltIn::GlobalInvocationId, ShaderStage::Compute) - | (BuiltIn::LocalInvocationId, ShaderStage::Compute) - | (BuiltIn::LocalInvocationIndex, ShaderStage::Compute) - | (BuiltIn::WorkGroupId, ShaderStage::Compute) - | (BuiltIn::WorkGroupSize, ShaderStage::Compute) => true, - _ => false, - } -} - -fn should_write(built_in: BuiltIn, stage: ShaderStage) -> bool { - match (built_in, stage) { - (BuiltIn::Position, ShaderStage::Vertex) - | (BuiltIn::ClipDistance, ShaderStage::Vertex) - | (BuiltIn::PointSize, ShaderStage::Vertex) - | (BuiltIn::FragDepth, ShaderStage::Fragment) - | (BuiltIn::SampleMask, ShaderStage::Fragment) => true, - _ => false, - } -} diff --git a/src/front/glsl/parser.rs b/src/front/glsl/parser.rs index b995ea5d6c..6161f563fc 100644 --- a/src/front/glsl/parser.rs +++ b/src/front/glsl/parser.rs @@ -1,7 +1,8 @@ use super::{ ast::{ - Context, FunctionCall, FunctionCallKind, FunctionSignature, GlobalLookup, HirExpr, - HirExprKind, ParameterQualifier, Profile, StorageQualifier, StructLayout, TypeQualifier, + Context, FunctionCall, FunctionCallKind, FunctionSignature, GlobalLookup, GlobalLookupKind, + HirExpr, HirExprKind, ParameterQualifier, Profile, StorageQualifier, StructLayout, + TypeQualifier, }, error::ErrorKind, lex::Lexer, @@ -623,7 +624,8 @@ impl<'source, 'program, 'options> Parser<'source, 'program, 'options> { // parse the body self.parse_compound_statement(&mut context, &mut body)?; - self.program.add_function( + let Context { arg_use, .. } = context; + let handle = self.program.add_function( Function { name: Some(name), result, @@ -637,6 +639,8 @@ impl<'source, 'program, 'options> Parser<'source, 'program, 'options> { meta, )?; + self.program.function_arg_use[handle.index()] = arg_use; + Ok(true) } _ => Err(ErrorKind::InvalidToken(token)), @@ -800,9 +804,13 @@ impl<'source, 'program, 'options> Parser<'source, 'program, 'options> { }); if let Some(k) = name { - self.program - .global_variables - .push((k, GlobalLookup::Variable(handle))); + self.program.global_variables.push(( + k, + GlobalLookup { + kind: GlobalLookupKind::Variable(handle), + entry_arg: None, + }, + )); } for (i, k) in members @@ -810,9 +818,13 @@ impl<'source, 'program, 'options> Parser<'source, 'program, 'options> { .enumerate() .filter_map(|(i, m)| m.name.map(|s| (i as u32, s))) { - self.program - .global_variables - .push((k, GlobalLookup::BlockSelect(handle, i))); + self.program.global_variables.push(( + k, + GlobalLookup { + kind: GlobalLookupKind::BlockSelect(handle, i), + entry_arg: None, + }, + )); } Ok(true) diff --git a/src/front/glsl/variables.rs b/src/front/glsl/variables.rs index 8a3bbee5dd..966b639ee4 100644 --- a/src/front/glsl/variables.rs +++ b/src/front/glsl/variables.rs @@ -45,11 +45,17 @@ impl Program<'_> { storage_access: StorageAccess::empty(), }); - self.entry_args - .push((Binding::BuiltIn(builtin), true, handle)); + let idx = self.entry_args.len(); + self.entry_args.push((Binding::BuiltIn(builtin), handle)); - self.global_variables - .push((name.into(), GlobalLookup::Variable(handle))); + self.global_variables.push(( + name.into(), + GlobalLookup { + kind: GlobalLookupKind::Variable(handle), + entry_arg: Some(idx), + }, + )); + ctx.arg_use.push(EntryArgUse::empty()); let expr = ctx.add_expression(Expression::GlobalVariable(handle), body); let load = ctx.add_expression(Expression::Load { pointer: expr }, body); @@ -59,6 +65,7 @@ impl Program<'_> { expr, load: Some(load), mutable, + entry_arg: Some(idx), }, ); @@ -297,7 +304,6 @@ impl Program<'_> { } if let Some(location) = location { - let input = StorageQualifier::Input == storage; let interpolation = self.module.types[ty].inner.scalar_kind().map(|kind| { if let ScalarKind::Float = kind { Interpolation::Perspective @@ -315,18 +321,23 @@ impl Program<'_> { storage_access: StorageAccess::empty(), }); + let idx = self.entry_args.len(); self.entry_args.push(( Binding::Location { location, interpolation, sampling, }, - input, handle, )); - self.global_variables - .push((name, GlobalLookup::Variable(handle))); + self.global_variables.push(( + name, + GlobalLookup { + kind: GlobalLookupKind::Variable(handle), + entry_arg: Some(idx), + }, + )); return Ok(ctx.add_expression(Expression::GlobalVariable(handle), body)); } else if let StorageQualifier::Const = storage { @@ -369,8 +380,13 @@ impl Program<'_> { storage_access, }); - self.global_variables - .push((name, GlobalLookup::Variable(handle))); + self.global_variables.push(( + name, + GlobalLookup { + kind: GlobalLookupKind::Variable(handle), + entry_arg: None, + }, + )); Ok(ctx.add_expression(Expression::GlobalVariable(handle), body)) }