[glsl-in] Collect entry point arguments usage

This commit is contained in:
João Capucho
2021-06-01 19:06:14 +01:00
committed by Dzmitry Malyshau
parent 61bfb29963
commit 027634451d
4 changed files with 176 additions and 80 deletions

View File

@@ -12,11 +12,17 @@ use crate::{
};
#[derive(Debug, Clone, Copy)]
pub enum GlobalLookup {
pub enum GlobalLookupKind {
Variable(Handle<GlobalVariable>),
BlockSelect(Handle<GlobalVariable>, u32),
}
#[derive(Debug, Clone, Copy)]
pub struct GlobalLookup {
pub kind: GlobalLookupKind,
pub entry_arg: Option<usize>,
}
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct FunctionSignature {
pub name: String,
@@ -33,6 +39,13 @@ pub struct FunctionDeclaration {
pub void: bool,
}
bitflags::bitflags! {
pub struct EntryArgUse: u32 {
const READ = 0x1;
const WRITE = 0x2;
}
}
#[derive(Debug)]
pub struct Program<'a> {
pub version: u16,
@@ -45,8 +58,10 @@ pub struct Program<'a> {
pub global_variables: Vec<(String, GlobalLookup)>,
pub constants: Vec<(String, Handle<Constant>)>,
pub entry_args: Vec<(Binding, bool, Handle<GlobalVariable>)>,
pub entry_args: Vec<(Binding, Handle<GlobalVariable>)>,
pub entries: Vec<(String, ShaderStage, Handle<Function>)>,
// TODO: More efficient representation
pub function_arg_use: Vec<Vec<EntryArgUse>>,
pub module: Module,
}
@@ -65,6 +80,7 @@ impl<'a> Program<'a> {
entry_args: Vec::new(),
entries: Vec::new(),
function_arg_use: Vec::new(),
module: Module::default(),
}
@@ -150,6 +166,7 @@ pub struct Context<'function> {
expressions: &'function mut Arena<Expression>,
pub locals: &'function mut Arena<LocalVariable>,
pub arguments: &'function mut Vec<FunctionArgument>,
pub arg_use: Vec<EntryArgUse>,
//TODO: Find less allocation heavy representation
pub scopes: Vec<FastHashMap<String, VariableReference>>,
@@ -173,6 +190,7 @@ impl<'function> Context<'function> {
expressions,
locals,
arguments,
arg_use: vec![EntryArgUse::empty(); program.entry_args.len()],
scopes: vec![FastHashMap::default()],
lookup_global_var_exps: FastHashMap::with_capacity_and_hasher(
@@ -194,6 +212,7 @@ impl<'function> Context<'function> {
expr,
load: None,
mutable: false,
entry_arg: None,
};
this.lookup_global_var_exps.insert(name.into(), var);
@@ -201,12 +220,13 @@ impl<'function> Context<'function> {
for &(ref name, lookup) in program.global_variables.iter() {
this.emit_flush(body);
let (expr, load) = match lookup {
GlobalLookup::Variable(v) => (
let GlobalLookup { kind, entry_arg } = lookup;
let (expr, load) = match kind {
GlobalLookupKind::Variable(v) => (
Expression::GlobalVariable(v),
program.module.global_variables[v].class != StorageClass::Handle,
),
GlobalLookup::BlockSelect(handle, index) => {
GlobalLookupKind::BlockSelect(handle, index) => {
let base = this.expressions.append(Expression::GlobalVariable(handle));
(Expression::AccessIndex { base, index }, true)
@@ -225,6 +245,7 @@ impl<'function> Context<'function> {
},
// TODO: respect constant qualifier
mutable: true,
entry_arg,
};
this.lookup_global_var_exps.insert(name.into(), var);
@@ -285,6 +306,7 @@ impl<'function> Context<'function> {
expr,
load: Some(load),
mutable,
entry_arg: None,
},
);
}
@@ -336,6 +358,7 @@ impl<'function> Context<'function> {
expr,
load,
mutable,
entry_arg: None,
},
);
}
@@ -455,8 +478,16 @@ impl<'function> Context<'function> {
));
}
if let Some(idx) = var.entry_arg {
self.arg_use[idx] |= EntryArgUse::WRITE
}
var.expr
} else {
if let Some(idx) = var.entry_arg {
self.arg_use[idx] |= EntryArgUse::READ
}
var.load.unwrap_or(var.expr)
}
}
@@ -617,6 +648,7 @@ pub struct VariableReference {
pub expr: Handle<Expression>,
pub load: Option<Handle<Expression>>,
pub mutable: bool,
pub entry_arg: Option<usize>,
}
#[derive(Debug, Clone)]

View File

@@ -1,8 +1,7 @@
use crate::{
proc::ensure_block_returns, Arena, BinaryOperator, Binding, Block, BuiltIn, EntryPoint,
Expression, Function, FunctionArgument, FunctionResult, Handle, MathFunction,
RelationalFunction, SampleLevel, ScalarKind, ShaderStage, Statement, StructMember,
SwizzleComponent, Type, TypeInner,
proc::ensure_block_returns, Arena, BinaryOperator, Block, EntryPoint, Expression, Function,
FunctionArgument, FunctionResult, Handle, MathFunction, RelationalFunction, SampleLevel,
ScalarKind, Statement, StructMember, SwizzleComponent, Type, TypeInner,
};
use super::{ast::*, error::ErrorKind, SourceMetadata};
@@ -392,13 +391,15 @@ impl Program<'_> {
sig: FunctionSignature,
qualifiers: Vec<ParameterQualifier>,
meta: SourceMetadata,
) -> Result<(), ErrorKind> {
) -> Result<Handle<Function>, ErrorKind> {
ensure_block_returns(&mut function.body);
let stage = self.entry_points.get(&sig.name);
if let Some(&stage) = stage {
Ok(if let Some(&stage) = stage {
let handle = self.module.functions.append(function);
self.entries.push((sig.name, stage, handle));
self.function_arg_use.push(Vec::new());
handle
} else {
let void = function.result.is_none();
@@ -412,7 +413,9 @@ impl Program<'_> {
decl.defined = true;
*self.module.functions.get_mut(decl.handle) = function;
decl.handle
} else {
self.function_arg_use.push(Vec::new());
let handle = self.module.functions.append(function);
self.lookup_function.insert(
sig,
@@ -423,10 +426,9 @@ impl Program<'_> {
void,
},
);
handle
}
}
Ok(())
})
}
pub fn add_prototype(
@@ -438,6 +440,7 @@ impl Program<'_> {
) -> Result<(), ErrorKind> {
let void = function.result.is_none();
self.function_arg_use.push(Vec::new());
let handle = self.module.functions.append(function);
if self
@@ -462,17 +465,86 @@ impl Program<'_> {
Ok(())
}
fn check_call_global(
&self,
caller: Handle<Function>,
function_arg_use: &mut [Vec<EntryArgUse>],
stmt: &Statement,
) {
match *stmt {
Statement::Block(ref block) => {
for stmt in block {
self.check_call_global(caller, function_arg_use, stmt)
}
}
Statement::If {
ref accept,
ref reject,
..
} => {
for stmt in accept.iter().chain(reject.iter()) {
self.check_call_global(caller, function_arg_use, stmt)
}
}
Statement::Switch {
ref cases,
ref default,
..
} => {
for stmt in cases
.iter()
.flat_map(|c| c.body.iter())
.chain(default.iter())
{
self.check_call_global(caller, function_arg_use, stmt)
}
}
Statement::Loop {
ref body,
ref continuing,
} => {
for stmt in body.iter().chain(continuing.iter()) {
self.check_call_global(caller, function_arg_use, stmt)
}
}
Statement::Call { function, .. } => {
let callee_len = function_arg_use[function.index()].len();
let caller_len = function_arg_use[caller.index()].len();
function_arg_use[caller.index()].extend(
std::iter::repeat(EntryArgUse::empty())
.take(callee_len.saturating_sub(caller_len)),
);
for i in 0..callee_len.max(caller_len) {
let callee_use = function_arg_use[function.index()][i];
function_arg_use[caller.index()][i] |= callee_use
}
}
_ => {}
}
}
pub fn add_entry_points(&mut self) {
let mut function_arg_use = Vec::new();
std::mem::swap(&mut self.function_arg_use, &mut function_arg_use);
for (handle, function) in self.module.functions.iter() {
for stmt in function.body.iter() {
self.check_call_global(handle, &mut function_arg_use, stmt)
}
}
for (name, stage, function) in self.entries.iter().cloned() {
let mut arguments = Vec::new();
let mut expressions = Arena::new();
let mut body = Vec::new();
for (binding, input, handle) in self.entry_args.iter().cloned() {
match binding {
Binding::Location { .. } if !input => continue,
Binding::BuiltIn(builtin) if !should_read(builtin, stage) => continue,
_ => {}
for (i, (binding, handle)) in self.entry_args.iter().cloned().enumerate() {
if function_arg_use[function.index()]
.get(i)
.map_or(true, |u| !u.contains(EntryArgUse::READ))
{
continue;
}
let ty = self.module.global_variables[handle].ty;
@@ -500,11 +572,12 @@ impl Program<'_> {
let mut members = Vec::new();
let mut components = Vec::new();
for (binding, input, handle) in self.entry_args.iter().cloned() {
match binding {
Binding::Location { .. } if input => continue,
Binding::BuiltIn(builtin) if !should_write(builtin, stage) => continue,
_ => {}
for (i, (binding, handle)) in self.entry_args.iter().cloned().enumerate() {
if function_arg_use[function.index()]
.get(i)
.map_or(true, |u| !u.contains(EntryArgUse::WRITE))
{
continue;
}
let ty = self.module.global_variables[handle].ty;
@@ -556,40 +629,3 @@ impl Program<'_> {
}
}
}
// FIXME: Both of the functions below should be removed they are a temporary solution
//
// The fix should analyze the entry point and children function calls
// (recursively) and store something like `GlobalUse` and then later only read
// or store the globals that need to be read or written in that stage
fn should_read(built_in: BuiltIn, stage: ShaderStage) -> bool {
match (built_in, stage) {
(BuiltIn::Position, ShaderStage::Fragment)
| (BuiltIn::BaseInstance, ShaderStage::Vertex)
| (BuiltIn::BaseVertex, ShaderStage::Vertex)
| (BuiltIn::ClipDistance, ShaderStage::Fragment)
| (BuiltIn::InstanceIndex, ShaderStage::Vertex)
| (BuiltIn::VertexIndex, ShaderStage::Vertex)
| (BuiltIn::FrontFacing, ShaderStage::Fragment)
| (BuiltIn::SampleIndex, ShaderStage::Fragment)
| (BuiltIn::SampleMask, ShaderStage::Fragment)
| (BuiltIn::GlobalInvocationId, ShaderStage::Compute)
| (BuiltIn::LocalInvocationId, ShaderStage::Compute)
| (BuiltIn::LocalInvocationIndex, ShaderStage::Compute)
| (BuiltIn::WorkGroupId, ShaderStage::Compute)
| (BuiltIn::WorkGroupSize, ShaderStage::Compute) => true,
_ => false,
}
}
fn should_write(built_in: BuiltIn, stage: ShaderStage) -> bool {
match (built_in, stage) {
(BuiltIn::Position, ShaderStage::Vertex)
| (BuiltIn::ClipDistance, ShaderStage::Vertex)
| (BuiltIn::PointSize, ShaderStage::Vertex)
| (BuiltIn::FragDepth, ShaderStage::Fragment)
| (BuiltIn::SampleMask, ShaderStage::Fragment) => true,
_ => false,
}
}

View File

@@ -1,7 +1,8 @@
use super::{
ast::{
Context, FunctionCall, FunctionCallKind, FunctionSignature, GlobalLookup, HirExpr,
HirExprKind, ParameterQualifier, Profile, StorageQualifier, StructLayout, TypeQualifier,
Context, FunctionCall, FunctionCallKind, FunctionSignature, GlobalLookup, GlobalLookupKind,
HirExpr, HirExprKind, ParameterQualifier, Profile, StorageQualifier, StructLayout,
TypeQualifier,
},
error::ErrorKind,
lex::Lexer,
@@ -623,7 +624,8 @@ impl<'source, 'program, 'options> Parser<'source, 'program, 'options> {
// parse the body
self.parse_compound_statement(&mut context, &mut body)?;
self.program.add_function(
let Context { arg_use, .. } = context;
let handle = self.program.add_function(
Function {
name: Some(name),
result,
@@ -637,6 +639,8 @@ impl<'source, 'program, 'options> Parser<'source, 'program, 'options> {
meta,
)?;
self.program.function_arg_use[handle.index()] = arg_use;
Ok(true)
}
_ => Err(ErrorKind::InvalidToken(token)),
@@ -800,9 +804,13 @@ impl<'source, 'program, 'options> Parser<'source, 'program, 'options> {
});
if let Some(k) = name {
self.program
.global_variables
.push((k, GlobalLookup::Variable(handle)));
self.program.global_variables.push((
k,
GlobalLookup {
kind: GlobalLookupKind::Variable(handle),
entry_arg: None,
},
));
}
for (i, k) in members
@@ -810,9 +818,13 @@ impl<'source, 'program, 'options> Parser<'source, 'program, 'options> {
.enumerate()
.filter_map(|(i, m)| m.name.map(|s| (i as u32, s)))
{
self.program
.global_variables
.push((k, GlobalLookup::BlockSelect(handle, i)));
self.program.global_variables.push((
k,
GlobalLookup {
kind: GlobalLookupKind::BlockSelect(handle, i),
entry_arg: None,
},
));
}
Ok(true)

View File

@@ -45,11 +45,17 @@ impl Program<'_> {
storage_access: StorageAccess::empty(),
});
self.entry_args
.push((Binding::BuiltIn(builtin), true, handle));
let idx = self.entry_args.len();
self.entry_args.push((Binding::BuiltIn(builtin), handle));
self.global_variables
.push((name.into(), GlobalLookup::Variable(handle)));
self.global_variables.push((
name.into(),
GlobalLookup {
kind: GlobalLookupKind::Variable(handle),
entry_arg: Some(idx),
},
));
ctx.arg_use.push(EntryArgUse::empty());
let expr = ctx.add_expression(Expression::GlobalVariable(handle), body);
let load = ctx.add_expression(Expression::Load { pointer: expr }, body);
@@ -59,6 +65,7 @@ impl Program<'_> {
expr,
load: Some(load),
mutable,
entry_arg: Some(idx),
},
);
@@ -297,7 +304,6 @@ impl Program<'_> {
}
if let Some(location) = location {
let input = StorageQualifier::Input == storage;
let interpolation = self.module.types[ty].inner.scalar_kind().map(|kind| {
if let ScalarKind::Float = kind {
Interpolation::Perspective
@@ -315,18 +321,23 @@ impl Program<'_> {
storage_access: StorageAccess::empty(),
});
let idx = self.entry_args.len();
self.entry_args.push((
Binding::Location {
location,
interpolation,
sampling,
},
input,
handle,
));
self.global_variables
.push((name, GlobalLookup::Variable(handle)));
self.global_variables.push((
name,
GlobalLookup {
kind: GlobalLookupKind::Variable(handle),
entry_arg: Some(idx),
},
));
return Ok(ctx.add_expression(Expression::GlobalVariable(handle), body));
} else if let StorageQualifier::Const = storage {
@@ -369,8 +380,13 @@ impl Program<'_> {
storage_access,
});
self.global_variables
.push((name, GlobalLookup::Variable(handle)));
self.global_variables.push((
name,
GlobalLookup {
kind: GlobalLookupKind::Variable(handle),
entry_arg: None,
},
));
Ok(ctx.add_expression(Expression::GlobalVariable(handle), body))
}