mirror of
https://github.com/gfx-rs/wgpu.git
synced 2026-04-22 03:02:01 -04:00
[glsl-in] Collect entry point arguments usage
This commit is contained in:
committed by
Dzmitry Malyshau
parent
61bfb29963
commit
027634451d
@@ -12,11 +12,17 @@ use crate::{
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum GlobalLookup {
|
||||
pub enum GlobalLookupKind {
|
||||
Variable(Handle<GlobalVariable>),
|
||||
BlockSelect(Handle<GlobalVariable>, u32),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct GlobalLookup {
|
||||
pub kind: GlobalLookupKind,
|
||||
pub entry_arg: Option<usize>,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Hash)]
|
||||
pub struct FunctionSignature {
|
||||
pub name: String,
|
||||
@@ -33,6 +39,13 @@ pub struct FunctionDeclaration {
|
||||
pub void: bool,
|
||||
}
|
||||
|
||||
bitflags::bitflags! {
|
||||
pub struct EntryArgUse: u32 {
|
||||
const READ = 0x1;
|
||||
const WRITE = 0x2;
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Program<'a> {
|
||||
pub version: u16,
|
||||
@@ -45,8 +58,10 @@ pub struct Program<'a> {
|
||||
pub global_variables: Vec<(String, GlobalLookup)>,
|
||||
pub constants: Vec<(String, Handle<Constant>)>,
|
||||
|
||||
pub entry_args: Vec<(Binding, bool, Handle<GlobalVariable>)>,
|
||||
pub entry_args: Vec<(Binding, Handle<GlobalVariable>)>,
|
||||
pub entries: Vec<(String, ShaderStage, Handle<Function>)>,
|
||||
// TODO: More efficient representation
|
||||
pub function_arg_use: Vec<Vec<EntryArgUse>>,
|
||||
|
||||
pub module: Module,
|
||||
}
|
||||
@@ -65,6 +80,7 @@ impl<'a> Program<'a> {
|
||||
|
||||
entry_args: Vec::new(),
|
||||
entries: Vec::new(),
|
||||
function_arg_use: Vec::new(),
|
||||
|
||||
module: Module::default(),
|
||||
}
|
||||
@@ -150,6 +166,7 @@ pub struct Context<'function> {
|
||||
expressions: &'function mut Arena<Expression>,
|
||||
pub locals: &'function mut Arena<LocalVariable>,
|
||||
pub arguments: &'function mut Vec<FunctionArgument>,
|
||||
pub arg_use: Vec<EntryArgUse>,
|
||||
|
||||
//TODO: Find less allocation heavy representation
|
||||
pub scopes: Vec<FastHashMap<String, VariableReference>>,
|
||||
@@ -173,6 +190,7 @@ impl<'function> Context<'function> {
|
||||
expressions,
|
||||
locals,
|
||||
arguments,
|
||||
arg_use: vec![EntryArgUse::empty(); program.entry_args.len()],
|
||||
|
||||
scopes: vec![FastHashMap::default()],
|
||||
lookup_global_var_exps: FastHashMap::with_capacity_and_hasher(
|
||||
@@ -194,6 +212,7 @@ impl<'function> Context<'function> {
|
||||
expr,
|
||||
load: None,
|
||||
mutable: false,
|
||||
entry_arg: None,
|
||||
};
|
||||
|
||||
this.lookup_global_var_exps.insert(name.into(), var);
|
||||
@@ -201,12 +220,13 @@ impl<'function> Context<'function> {
|
||||
|
||||
for &(ref name, lookup) in program.global_variables.iter() {
|
||||
this.emit_flush(body);
|
||||
let (expr, load) = match lookup {
|
||||
GlobalLookup::Variable(v) => (
|
||||
let GlobalLookup { kind, entry_arg } = lookup;
|
||||
let (expr, load) = match kind {
|
||||
GlobalLookupKind::Variable(v) => (
|
||||
Expression::GlobalVariable(v),
|
||||
program.module.global_variables[v].class != StorageClass::Handle,
|
||||
),
|
||||
GlobalLookup::BlockSelect(handle, index) => {
|
||||
GlobalLookupKind::BlockSelect(handle, index) => {
|
||||
let base = this.expressions.append(Expression::GlobalVariable(handle));
|
||||
|
||||
(Expression::AccessIndex { base, index }, true)
|
||||
@@ -225,6 +245,7 @@ impl<'function> Context<'function> {
|
||||
},
|
||||
// TODO: respect constant qualifier
|
||||
mutable: true,
|
||||
entry_arg,
|
||||
};
|
||||
|
||||
this.lookup_global_var_exps.insert(name.into(), var);
|
||||
@@ -285,6 +306,7 @@ impl<'function> Context<'function> {
|
||||
expr,
|
||||
load: Some(load),
|
||||
mutable,
|
||||
entry_arg: None,
|
||||
},
|
||||
);
|
||||
}
|
||||
@@ -336,6 +358,7 @@ impl<'function> Context<'function> {
|
||||
expr,
|
||||
load,
|
||||
mutable,
|
||||
entry_arg: None,
|
||||
},
|
||||
);
|
||||
}
|
||||
@@ -455,8 +478,16 @@ impl<'function> Context<'function> {
|
||||
));
|
||||
}
|
||||
|
||||
if let Some(idx) = var.entry_arg {
|
||||
self.arg_use[idx] |= EntryArgUse::WRITE
|
||||
}
|
||||
|
||||
var.expr
|
||||
} else {
|
||||
if let Some(idx) = var.entry_arg {
|
||||
self.arg_use[idx] |= EntryArgUse::READ
|
||||
}
|
||||
|
||||
var.load.unwrap_or(var.expr)
|
||||
}
|
||||
}
|
||||
@@ -617,6 +648,7 @@ pub struct VariableReference {
|
||||
pub expr: Handle<Expression>,
|
||||
pub load: Option<Handle<Expression>>,
|
||||
pub mutable: bool,
|
||||
pub entry_arg: Option<usize>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
use crate::{
|
||||
proc::ensure_block_returns, Arena, BinaryOperator, Binding, Block, BuiltIn, EntryPoint,
|
||||
Expression, Function, FunctionArgument, FunctionResult, Handle, MathFunction,
|
||||
RelationalFunction, SampleLevel, ScalarKind, ShaderStage, Statement, StructMember,
|
||||
SwizzleComponent, Type, TypeInner,
|
||||
proc::ensure_block_returns, Arena, BinaryOperator, Block, EntryPoint, Expression, Function,
|
||||
FunctionArgument, FunctionResult, Handle, MathFunction, RelationalFunction, SampleLevel,
|
||||
ScalarKind, Statement, StructMember, SwizzleComponent, Type, TypeInner,
|
||||
};
|
||||
|
||||
use super::{ast::*, error::ErrorKind, SourceMetadata};
|
||||
@@ -392,13 +391,15 @@ impl Program<'_> {
|
||||
sig: FunctionSignature,
|
||||
qualifiers: Vec<ParameterQualifier>,
|
||||
meta: SourceMetadata,
|
||||
) -> Result<(), ErrorKind> {
|
||||
) -> Result<Handle<Function>, ErrorKind> {
|
||||
ensure_block_returns(&mut function.body);
|
||||
let stage = self.entry_points.get(&sig.name);
|
||||
|
||||
if let Some(&stage) = stage {
|
||||
Ok(if let Some(&stage) = stage {
|
||||
let handle = self.module.functions.append(function);
|
||||
self.entries.push((sig.name, stage, handle));
|
||||
self.function_arg_use.push(Vec::new());
|
||||
handle
|
||||
} else {
|
||||
let void = function.result.is_none();
|
||||
|
||||
@@ -412,7 +413,9 @@ impl Program<'_> {
|
||||
|
||||
decl.defined = true;
|
||||
*self.module.functions.get_mut(decl.handle) = function;
|
||||
decl.handle
|
||||
} else {
|
||||
self.function_arg_use.push(Vec::new());
|
||||
let handle = self.module.functions.append(function);
|
||||
self.lookup_function.insert(
|
||||
sig,
|
||||
@@ -423,10 +426,9 @@ impl Program<'_> {
|
||||
void,
|
||||
},
|
||||
);
|
||||
handle
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
pub fn add_prototype(
|
||||
@@ -438,6 +440,7 @@ impl Program<'_> {
|
||||
) -> Result<(), ErrorKind> {
|
||||
let void = function.result.is_none();
|
||||
|
||||
self.function_arg_use.push(Vec::new());
|
||||
let handle = self.module.functions.append(function);
|
||||
|
||||
if self
|
||||
@@ -462,17 +465,86 @@ impl Program<'_> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn check_call_global(
|
||||
&self,
|
||||
caller: Handle<Function>,
|
||||
function_arg_use: &mut [Vec<EntryArgUse>],
|
||||
stmt: &Statement,
|
||||
) {
|
||||
match *stmt {
|
||||
Statement::Block(ref block) => {
|
||||
for stmt in block {
|
||||
self.check_call_global(caller, function_arg_use, stmt)
|
||||
}
|
||||
}
|
||||
Statement::If {
|
||||
ref accept,
|
||||
ref reject,
|
||||
..
|
||||
} => {
|
||||
for stmt in accept.iter().chain(reject.iter()) {
|
||||
self.check_call_global(caller, function_arg_use, stmt)
|
||||
}
|
||||
}
|
||||
Statement::Switch {
|
||||
ref cases,
|
||||
ref default,
|
||||
..
|
||||
} => {
|
||||
for stmt in cases
|
||||
.iter()
|
||||
.flat_map(|c| c.body.iter())
|
||||
.chain(default.iter())
|
||||
{
|
||||
self.check_call_global(caller, function_arg_use, stmt)
|
||||
}
|
||||
}
|
||||
Statement::Loop {
|
||||
ref body,
|
||||
ref continuing,
|
||||
} => {
|
||||
for stmt in body.iter().chain(continuing.iter()) {
|
||||
self.check_call_global(caller, function_arg_use, stmt)
|
||||
}
|
||||
}
|
||||
Statement::Call { function, .. } => {
|
||||
let callee_len = function_arg_use[function.index()].len();
|
||||
let caller_len = function_arg_use[caller.index()].len();
|
||||
function_arg_use[caller.index()].extend(
|
||||
std::iter::repeat(EntryArgUse::empty())
|
||||
.take(callee_len.saturating_sub(caller_len)),
|
||||
);
|
||||
|
||||
for i in 0..callee_len.max(caller_len) {
|
||||
let callee_use = function_arg_use[function.index()][i];
|
||||
function_arg_use[caller.index()][i] |= callee_use
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn add_entry_points(&mut self) {
|
||||
let mut function_arg_use = Vec::new();
|
||||
std::mem::swap(&mut self.function_arg_use, &mut function_arg_use);
|
||||
|
||||
for (handle, function) in self.module.functions.iter() {
|
||||
for stmt in function.body.iter() {
|
||||
self.check_call_global(handle, &mut function_arg_use, stmt)
|
||||
}
|
||||
}
|
||||
|
||||
for (name, stage, function) in self.entries.iter().cloned() {
|
||||
let mut arguments = Vec::new();
|
||||
let mut expressions = Arena::new();
|
||||
let mut body = Vec::new();
|
||||
|
||||
for (binding, input, handle) in self.entry_args.iter().cloned() {
|
||||
match binding {
|
||||
Binding::Location { .. } if !input => continue,
|
||||
Binding::BuiltIn(builtin) if !should_read(builtin, stage) => continue,
|
||||
_ => {}
|
||||
for (i, (binding, handle)) in self.entry_args.iter().cloned().enumerate() {
|
||||
if function_arg_use[function.index()]
|
||||
.get(i)
|
||||
.map_or(true, |u| !u.contains(EntryArgUse::READ))
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
let ty = self.module.global_variables[handle].ty;
|
||||
@@ -500,11 +572,12 @@ impl Program<'_> {
|
||||
let mut members = Vec::new();
|
||||
let mut components = Vec::new();
|
||||
|
||||
for (binding, input, handle) in self.entry_args.iter().cloned() {
|
||||
match binding {
|
||||
Binding::Location { .. } if input => continue,
|
||||
Binding::BuiltIn(builtin) if !should_write(builtin, stage) => continue,
|
||||
_ => {}
|
||||
for (i, (binding, handle)) in self.entry_args.iter().cloned().enumerate() {
|
||||
if function_arg_use[function.index()]
|
||||
.get(i)
|
||||
.map_or(true, |u| !u.contains(EntryArgUse::WRITE))
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
let ty = self.module.global_variables[handle].ty;
|
||||
@@ -556,40 +629,3 @@ impl Program<'_> {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// FIXME: Both of the functions below should be removed they are a temporary solution
|
||||
//
|
||||
// The fix should analyze the entry point and children function calls
|
||||
// (recursively) and store something like `GlobalUse` and then later only read
|
||||
// or store the globals that need to be read or written in that stage
|
||||
|
||||
fn should_read(built_in: BuiltIn, stage: ShaderStage) -> bool {
|
||||
match (built_in, stage) {
|
||||
(BuiltIn::Position, ShaderStage::Fragment)
|
||||
| (BuiltIn::BaseInstance, ShaderStage::Vertex)
|
||||
| (BuiltIn::BaseVertex, ShaderStage::Vertex)
|
||||
| (BuiltIn::ClipDistance, ShaderStage::Fragment)
|
||||
| (BuiltIn::InstanceIndex, ShaderStage::Vertex)
|
||||
| (BuiltIn::VertexIndex, ShaderStage::Vertex)
|
||||
| (BuiltIn::FrontFacing, ShaderStage::Fragment)
|
||||
| (BuiltIn::SampleIndex, ShaderStage::Fragment)
|
||||
| (BuiltIn::SampleMask, ShaderStage::Fragment)
|
||||
| (BuiltIn::GlobalInvocationId, ShaderStage::Compute)
|
||||
| (BuiltIn::LocalInvocationId, ShaderStage::Compute)
|
||||
| (BuiltIn::LocalInvocationIndex, ShaderStage::Compute)
|
||||
| (BuiltIn::WorkGroupId, ShaderStage::Compute)
|
||||
| (BuiltIn::WorkGroupSize, ShaderStage::Compute) => true,
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn should_write(built_in: BuiltIn, stage: ShaderStage) -> bool {
|
||||
match (built_in, stage) {
|
||||
(BuiltIn::Position, ShaderStage::Vertex)
|
||||
| (BuiltIn::ClipDistance, ShaderStage::Vertex)
|
||||
| (BuiltIn::PointSize, ShaderStage::Vertex)
|
||||
| (BuiltIn::FragDepth, ShaderStage::Fragment)
|
||||
| (BuiltIn::SampleMask, ShaderStage::Fragment) => true,
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
use super::{
|
||||
ast::{
|
||||
Context, FunctionCall, FunctionCallKind, FunctionSignature, GlobalLookup, HirExpr,
|
||||
HirExprKind, ParameterQualifier, Profile, StorageQualifier, StructLayout, TypeQualifier,
|
||||
Context, FunctionCall, FunctionCallKind, FunctionSignature, GlobalLookup, GlobalLookupKind,
|
||||
HirExpr, HirExprKind, ParameterQualifier, Profile, StorageQualifier, StructLayout,
|
||||
TypeQualifier,
|
||||
},
|
||||
error::ErrorKind,
|
||||
lex::Lexer,
|
||||
@@ -623,7 +624,8 @@ impl<'source, 'program, 'options> Parser<'source, 'program, 'options> {
|
||||
// parse the body
|
||||
self.parse_compound_statement(&mut context, &mut body)?;
|
||||
|
||||
self.program.add_function(
|
||||
let Context { arg_use, .. } = context;
|
||||
let handle = self.program.add_function(
|
||||
Function {
|
||||
name: Some(name),
|
||||
result,
|
||||
@@ -637,6 +639,8 @@ impl<'source, 'program, 'options> Parser<'source, 'program, 'options> {
|
||||
meta,
|
||||
)?;
|
||||
|
||||
self.program.function_arg_use[handle.index()] = arg_use;
|
||||
|
||||
Ok(true)
|
||||
}
|
||||
_ => Err(ErrorKind::InvalidToken(token)),
|
||||
@@ -800,9 +804,13 @@ impl<'source, 'program, 'options> Parser<'source, 'program, 'options> {
|
||||
});
|
||||
|
||||
if let Some(k) = name {
|
||||
self.program
|
||||
.global_variables
|
||||
.push((k, GlobalLookup::Variable(handle)));
|
||||
self.program.global_variables.push((
|
||||
k,
|
||||
GlobalLookup {
|
||||
kind: GlobalLookupKind::Variable(handle),
|
||||
entry_arg: None,
|
||||
},
|
||||
));
|
||||
}
|
||||
|
||||
for (i, k) in members
|
||||
@@ -810,9 +818,13 @@ impl<'source, 'program, 'options> Parser<'source, 'program, 'options> {
|
||||
.enumerate()
|
||||
.filter_map(|(i, m)| m.name.map(|s| (i as u32, s)))
|
||||
{
|
||||
self.program
|
||||
.global_variables
|
||||
.push((k, GlobalLookup::BlockSelect(handle, i)));
|
||||
self.program.global_variables.push((
|
||||
k,
|
||||
GlobalLookup {
|
||||
kind: GlobalLookupKind::BlockSelect(handle, i),
|
||||
entry_arg: None,
|
||||
},
|
||||
));
|
||||
}
|
||||
|
||||
Ok(true)
|
||||
|
||||
@@ -45,11 +45,17 @@ impl Program<'_> {
|
||||
storage_access: StorageAccess::empty(),
|
||||
});
|
||||
|
||||
self.entry_args
|
||||
.push((Binding::BuiltIn(builtin), true, handle));
|
||||
let idx = self.entry_args.len();
|
||||
self.entry_args.push((Binding::BuiltIn(builtin), handle));
|
||||
|
||||
self.global_variables
|
||||
.push((name.into(), GlobalLookup::Variable(handle)));
|
||||
self.global_variables.push((
|
||||
name.into(),
|
||||
GlobalLookup {
|
||||
kind: GlobalLookupKind::Variable(handle),
|
||||
entry_arg: Some(idx),
|
||||
},
|
||||
));
|
||||
ctx.arg_use.push(EntryArgUse::empty());
|
||||
|
||||
let expr = ctx.add_expression(Expression::GlobalVariable(handle), body);
|
||||
let load = ctx.add_expression(Expression::Load { pointer: expr }, body);
|
||||
@@ -59,6 +65,7 @@ impl Program<'_> {
|
||||
expr,
|
||||
load: Some(load),
|
||||
mutable,
|
||||
entry_arg: Some(idx),
|
||||
},
|
||||
);
|
||||
|
||||
@@ -297,7 +304,6 @@ impl Program<'_> {
|
||||
}
|
||||
|
||||
if let Some(location) = location {
|
||||
let input = StorageQualifier::Input == storage;
|
||||
let interpolation = self.module.types[ty].inner.scalar_kind().map(|kind| {
|
||||
if let ScalarKind::Float = kind {
|
||||
Interpolation::Perspective
|
||||
@@ -315,18 +321,23 @@ impl Program<'_> {
|
||||
storage_access: StorageAccess::empty(),
|
||||
});
|
||||
|
||||
let idx = self.entry_args.len();
|
||||
self.entry_args.push((
|
||||
Binding::Location {
|
||||
location,
|
||||
interpolation,
|
||||
sampling,
|
||||
},
|
||||
input,
|
||||
handle,
|
||||
));
|
||||
|
||||
self.global_variables
|
||||
.push((name, GlobalLookup::Variable(handle)));
|
||||
self.global_variables.push((
|
||||
name,
|
||||
GlobalLookup {
|
||||
kind: GlobalLookupKind::Variable(handle),
|
||||
entry_arg: Some(idx),
|
||||
},
|
||||
));
|
||||
|
||||
return Ok(ctx.add_expression(Expression::GlobalVariable(handle), body));
|
||||
} else if let StorageQualifier::Const = storage {
|
||||
@@ -369,8 +380,13 @@ impl Program<'_> {
|
||||
storage_access,
|
||||
});
|
||||
|
||||
self.global_variables
|
||||
.push((name, GlobalLookup::Variable(handle)));
|
||||
self.global_variables.push((
|
||||
name,
|
||||
GlobalLookup {
|
||||
kind: GlobalLookupKind::Variable(handle),
|
||||
entry_arg: None,
|
||||
},
|
||||
));
|
||||
|
||||
Ok(ctx.add_expression(Expression::GlobalVariable(handle), body))
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user