diff --git a/src/front/glsl/ast.rs b/src/front/glsl/ast.rs index ed2eecc8c1..aca29951d3 100644 --- a/src/front/glsl/ast.rs +++ b/src/front/glsl/ast.rs @@ -46,6 +46,31 @@ bitflags::bitflags! { } } +bitflags::bitflags! { + pub struct PrologueStage: u32 { + const VERTEX = 0x1; + const FRAGMENT = 0x2; + const COMPUTE = 0x4; + } +} + +impl From for PrologueStage { + fn from(stage: ShaderStage) -> Self { + match stage { + ShaderStage::Vertex => PrologueStage::VERTEX, + ShaderStage::Fragment => PrologueStage::FRAGMENT, + ShaderStage::Compute => PrologueStage::COMPUTE, + } + } +} + +#[derive(Debug)] +pub struct EntryArg { + pub binding: Binding, + pub handle: Handle, + pub prologue: PrologueStage, +} + #[derive(Debug)] pub struct Program<'a> { pub version: u16, @@ -58,7 +83,7 @@ pub struct Program<'a> { pub global_variables: Vec<(String, GlobalLookup)>, pub constants: Vec<(String, Handle)>, - pub entry_args: Vec<(Binding, Handle)>, + pub entry_args: Vec, pub entries: Vec<(String, ShaderStage, Handle)>, // TODO: More efficient representation pub function_arg_use: Vec>, diff --git a/src/front/glsl/functions.rs b/src/front/glsl/functions.rs index a70df7e50b..77bdfb5775 100644 --- a/src/front/glsl/functions.rs +++ b/src/front/glsl/functions.rs @@ -539,25 +539,26 @@ impl Program<'_> { let mut expressions = Arena::new(); let mut body = Vec::new(); - for (i, (binding, handle)) in self.entry_args.iter().cloned().enumerate() { + for (i, arg) in self.entry_args.iter().enumerate() { if function_arg_use[function.index()] .get(i) .map_or(true, |u| !u.contains(EntryArgUse::READ)) + || !arg.prologue.contains(stage.into()) { continue; } - let ty = self.module.global_variables[handle].ty; - let arg = arguments.len() as u32; + let ty = self.module.global_variables[arg.handle].ty; + let idx = arguments.len() as u32; arguments.push(FunctionArgument { name: None, ty, - binding: Some(binding), + binding: Some(arg.binding.clone()), }); - let pointer = expressions.append(Expression::GlobalVariable(handle)); - let value = expressions.append(Expression::FunctionArgument(arg)); + let pointer = expressions.append(Expression::GlobalVariable(arg.handle)); + let value = expressions.append(Expression::FunctionArgument(idx)); body.push(Statement::Store { pointer, value }); } @@ -572,7 +573,7 @@ impl Program<'_> { let mut members = Vec::new(); let mut components = Vec::new(); - for (i, (binding, handle)) in self.entry_args.iter().cloned().enumerate() { + for (i, arg) in self.entry_args.iter().enumerate() { if function_arg_use[function.index()] .get(i) .map_or(true, |u| !u.contains(EntryArgUse::WRITE)) @@ -580,18 +581,18 @@ impl Program<'_> { continue; } - let ty = self.module.global_variables[handle].ty; + let ty = self.module.global_variables[arg.handle].ty; members.push(StructMember { name: None, ty, - binding: Some(binding), + binding: Some(arg.binding.clone()), offset: span, }); span += self.module.types[ty].inner.span(&self.module.constants); - let pointer = expressions.append(Expression::GlobalVariable(handle)); + let pointer = expressions.append(Expression::GlobalVariable(arg.handle)); let len = expressions.len(); let load = expressions.append(Expression::Load { pointer }); body.push(Statement::Emit(expressions.range_from(len))); diff --git a/src/front/glsl/variables.rs b/src/front/glsl/variables.rs index 966b639ee4..399a9249fe 100644 --- a/src/front/glsl/variables.rs +++ b/src/front/glsl/variables.rs @@ -30,7 +30,7 @@ impl Program<'_> { return Ok(Some(global_var)); } - let mut add_builtin = |inner, builtin, mutable| { + let mut add_builtin = |inner, builtin, mutable, prologue| { let ty = self .module .types @@ -46,7 +46,11 @@ impl Program<'_> { }); let idx = self.entry_args.len(); - self.entry_args.push((Binding::BuiltIn(builtin), handle)); + self.entry_args.push(EntryArg { + binding: Binding::BuiltIn(builtin), + handle, + prologue, + }); self.global_variables.push(( name.into(), @@ -80,6 +84,7 @@ impl Program<'_> { }, BuiltIn::Position, true, + PrologueStage::FRAGMENT, ), "gl_VertexIndex" => add_builtin( TypeInner::Scalar { @@ -88,6 +93,7 @@ impl Program<'_> { }, BuiltIn::VertexIndex, false, + PrologueStage::VERTEX, ), "gl_InstanceIndex" => add_builtin( TypeInner::Scalar { @@ -96,6 +102,7 @@ impl Program<'_> { }, BuiltIn::InstanceIndex, false, + PrologueStage::VERTEX, ), _ => Ok(None), } @@ -304,6 +311,11 @@ impl Program<'_> { } if let Some(location) = location { + let prologue = if let StorageQualifier::Input = storage { + PrologueStage::all() + } else { + PrologueStage::empty() + }; let interpolation = self.module.types[ty].inner.scalar_kind().map(|kind| { if let ScalarKind::Float = kind { Interpolation::Perspective @@ -322,14 +334,15 @@ impl Program<'_> { }); let idx = self.entry_args.len(); - self.entry_args.push(( - Binding::Location { + self.entry_args.push(EntryArg { + binding: Binding::Location { location, interpolation, sampling, }, handle, - )); + prologue, + }); self.global_variables.push(( name,