From 2ebaadaf0ca62e558ac027606f6f18c807308ee7 Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Thu, 17 Sep 2020 00:35:32 -0400 Subject: [PATCH] Refactor entry point IR --- examples/convert.rs | 34 +-- src/back/glsl.rs | 46 +--- src/back/msl.rs | 393 ++++++++++++++------------- src/back/spv/writer.rs | 221 +++++++-------- src/front/glsl/ast.rs | 18 +- src/front/glsl/mod.rs | 21 +- src/front/glsl/parser.rs | 54 ++-- src/front/glsl/parser_tests.rs | 2 +- src/front/glsl/variables.rs | 70 ++--- src/front/mod.rs | 2 +- src/front/spv/function.rs | 245 +++++++++-------- src/front/spv/mod.rs | 128 ++++----- src/front/wgsl/mod.rs | 127 +++++---- src/lib.rs | 32 +-- src/proc/interface.rs | 51 ++-- test-data/boids.param.ron | 6 +- test-data/boids.wgsl | 51 ++-- test-data/function.wgsl | 5 +- test-data/quad.wgsl | 20 +- test-data/simple/simple.expected.ron | 104 ++++--- test-data/simple/simple.wgsl | 6 +- tests/convert.rs | 48 +++- 22 files changed, 846 insertions(+), 838 deletions(-) diff --git a/examples/convert.rs b/examples/convert.rs index 3b219540c5..dda2eb400f 100644 --- a/examples/convert.rs +++ b/examples/convert.rs @@ -3,7 +3,7 @@ use std::{env, fs, path::Path}; #[derive(Hash, PartialEq, Eq, Serialize, Deserialize)] struct BindSource { - set: u32, + group: u32, binding: u32, } @@ -54,26 +54,14 @@ fn main() { #[cfg(feature = "glsl-in")] "frag" => { let input = fs::read_to_string(&args[1]).unwrap(); - naga::front::glsl::parse_str( - &input, - "main".to_string(), - naga::ShaderStage::Fragment { - early_depth_test: None, - }, - ) - .unwrap() + naga::front::glsl::parse_str(&input, "main".to_string(), naga::ShaderStage::Fragment) + .unwrap() } #[cfg(feature = "glsl-in")] "comp" => { let input = fs::read_to_string(&args[1]).unwrap(); - naga::front::glsl::parse_str( - &input, - "main".to_string(), - naga::ShaderStage::Compute { - local_size: (0, 0, 0), - }, - ) - .unwrap() + naga::front::glsl::parse_str(&input, "main".to_string(), naga::ShaderStage::Compute) + .unwrap() } #[cfg(feature = "deserialize")] "ron" => { @@ -113,7 +101,7 @@ fn main() { for (key, value) in params.metal_bindings { binding_map.insert( msl::BindSource { - set: key.set, + group: key.group, binding: key.binding, }, msl::BindTarget { @@ -170,17 +158,13 @@ fn main() { let options = Options { version: Version::Embedded(310), entry_point: ( - String::from("main"), match stage { "vert" => ShaderStage::Vertex, - "frag" => ShaderStage::Fragment { - early_depth_test: None, - }, - "comp" => ShaderStage::Compute { - local_size: (0, 0, 0), - }, + "frag" => ShaderStage::Fragment, + "comp" => ShaderStage::Compute, _ => unreachable!(), }, + String::from("main"), ), }; diff --git a/src/back/glsl.rs b/src/back/glsl.rs index cbb8fd26ac..c54282cc5f 100644 --- a/src/back/glsl.rs +++ b/src/back/glsl.rs @@ -66,7 +66,7 @@ impl fmt::Display for Version { #[derive(Debug, Clone)] pub struct Options { pub version: Version, - pub entry_point: (String, ShaderStage), + pub entry_point: (ShaderStage, String), } #[derive(Debug, Clone)] @@ -141,12 +141,12 @@ pub fn write<'a>( let entry_point = module .entry_points - .iter() - .find(|entry| entry.name == options.entry_point.0 && entry.stage == options.entry_point.1) + .get(&options.entry_point) .ok_or_else(|| Error::Custom(String::from("Entry point not found")))?; - let func = &module.functions[entry_point.function]; + let func = &entry_point.function; + let stage = options.entry_point.0; - if let ShaderStage::Compute { .. } = entry_point.stage { + if let ShaderStage::Compute = stage { if (es && version < 310) || (!es && version < 430) { return Err(Error::Custom(format!( "Version {} doesn't support compute shaders", @@ -233,21 +233,7 @@ pub fn write<'a>( let mut functions = FastHashMap::default(); for (handle, func) in module.functions.iter() { - // Discard all entry points - if entry_point.function != handle - && module - .entry_points - .iter() - .any(|entry| entry.function == handle) - { - continue; - } - - let name = if entry_point.function != handle { - namer(func.name.as_ref()) - } else { - String::from("main") - }; + let name = namer(func.name.as_ref()); writeln!( out, @@ -406,8 +392,8 @@ pub fn write<'a>( } if let Some(interpolation) = global.interpolation { - match (entry_point.stage, global.class) { - (ShaderStage::Fragment { .. }, StorageClass::Input) + match (stage, global.class) { + (ShaderStage::Fragment, StorageClass::Input) | (ShaderStage::Vertex, StorageClass::Output) => { write!(out, "{} ", write_interpolation(interpolation)?)?; } @@ -1449,7 +1435,7 @@ fn write_format_glsl(format: StorageFormat) -> &'static str { struct TextureMappingVisitor<'a> { expressions: &'a Arena, - map: FastHashMap, Option>>, + map: &'a mut FastHashMap, Option>>, error: Option, } @@ -1503,19 +1489,15 @@ fn collect_texture_mapping( for function in functions.keys() { let func = &module.functions[*function]; - let mut visitor = TextureMappingVisitor { - expressions: &func.expressions, - map: FastHashMap::default(), - error: None, - }; - let mut interface = Interface { expressions: &func.expressions, - visitor: &mut visitor, + visitor: TextureMappingVisitor { + expressions: &func.expressions, + map: &mut mappings, + error: None, + }, }; interface.traverse(&func.body); - - mappings.extend(visitor.map); } Ok(mappings) diff --git a/src/back/msl.rs b/src/back/msl.rs index ed7e16027e..607b11aced 100644 --- a/src/back/msl.rs +++ b/src/back/msl.rs @@ -31,7 +31,7 @@ pub struct BindTarget { #[derive(Clone, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] pub struct BindSource { - pub set: u32, + pub group: u32, pub binding: u32, } @@ -64,7 +64,6 @@ pub enum Error { Format(FmtError), Type(ResolveError), UnexpectedLocation, - MixedExecutionModels(Handle), MissingBinding(Handle), MissingBindTarget(BindSource), InvalidImageAccess(crate::StorageAccess), @@ -125,8 +124,8 @@ impl Options<'_> { }), LocationMode::Uniform => Err(Error::UnexpectedLocation), }, - crate::Binding::Descriptor { set, binding } => { - let source = BindSource { set, binding }; + crate::Binding::Resource { group, binding } => { + let source = BindSource { group, binding }; self.binding_map .get(&source) .cloned() @@ -182,22 +181,6 @@ impl Indexed for ParameterIndex { self.0 } } -struct InputStructIndex(Handle); -impl Indexed for InputStructIndex { - const CLASS: &'static str = "Input"; - const PREFIX: bool = true; - fn id(&self) -> usize { - self.0.index() - } -} -struct OutputStructIndex(Handle); -impl Indexed for OutputStructIndex { - const CLASS: &'static str = "Output"; - const PREFIX: bool = true; - fn id(&self) -> usize { - self.0.index() - } -} enum NameSource<'a> { Custom { name: &'a str, prefix: bool }, @@ -924,8 +907,55 @@ impl Writer { )?; let fun_name = fun.name.or_index(fun_handle); + let result_type_handle = fun.return_type.unwrap(); + let result_type_name = module.types[result_type_handle] + .name + .or_index(result_type_handle); + writeln!(self.out, "{} {}(", result_type_name, fun_name)?; + + for (index, &ty) in fun.parameter_types.iter().enumerate() { + let name = Name::from(ParameterIndex(index)); + let member_type_name = module.types[ty].name.or_index(ty); + let separator = separate(index + 1 == fun.parameter_types.len()); + writeln!(self.out, "\t{} {}{}", member_type_name, name, separator)?; + } + writeln!(self.out, ") {{")?; + + for (local_handle, local) in fun.local_variables.iter() { + let ty_name = module.types[local.ty].name.or_index(local.ty); + write!( + self.out, + "\t{} {}", + ty_name, + local.name.or_index(local_handle) + )?; + if let Some(value) = local.init { + write!(self.out, " = ")?; + self.put_expression(value, fun, module)?; + } + writeln!(self.out, ";")?; + } + for statement in fun.body.iter() { + self.put_statement(Level(1), statement, fun, false, module)?; + } + writeln!(self.out, "}}")?; + } + + for (&(stage, ref name), ep) in module.entry_points.iter() { + let fun = &ep.function; + self.typifier.resolve_all( + &fun.expressions, + &module.types, + &ResolveContext { + constants: &module.constants, + global_vars: &module.global_variables, + local_vars: &fun.local_variables, + functions: &module.functions, + parameter_types: &fun.parameter_types, + }, + )?; + // find the entry point(s) and inputs/outputs - let mut shader_stage = None; let mut last_used_global = None; for ((handle, var), &usage) in module.global_variables.iter().zip(&fun.global_usage) { match var.class { @@ -941,206 +971,177 @@ impl Writer { last_used_global = Some(handle); } } - for ep in module.entry_points.iter() { - if ep.function == fun_handle { - if shader_stage.is_some() { - if shader_stage != Some(ep.stage) { - return Err(Error::MixedExecutionModels(fun_handle)); - } - } else { - shader_stage = Some(ep.stage); - } + + let output_name = Name { + class: "Output", + source: NameSource::Custom { name, prefix: true }, + }; + + let (em_str, in_mode, out_mode) = match stage { + crate::ShaderStage::Vertex => ( + "vertex", + LocationMode::VertexInput, + LocationMode::Intermediate, + ), + crate::ShaderStage::Fragment { .. } => ( + "fragment", + LocationMode::Intermediate, + LocationMode::FragmentOutput, + ), + crate::ShaderStage::Compute { .. } => { + ("kernel", LocationMode::Uniform, LocationMode::Uniform) } - } - let output_name = fun.name.or_index(OutputStructIndex(fun_handle)); + }; + let location_input_name = Name { + class: "Input", + source: NameSource::Custom { name, prefix: true }, + }; - // make dedicated input/output structs - if let Some(stage) = shader_stage { - assert_eq!(fun.return_type, None); - let (em_str, in_mode, out_mode) = match stage { - crate::ShaderStage::Vertex => ( - "vertex", - LocationMode::VertexInput, - LocationMode::Intermediate, - ), - crate::ShaderStage::Fragment { .. } => ( - "fragment", - LocationMode::Intermediate, - LocationMode::FragmentOutput, - ), - crate::ShaderStage::Compute { .. } => { - ("kernel", LocationMode::Uniform, LocationMode::Uniform) - } - }; - let location_input_name = fun.name.or_index(InputStructIndex(fun_handle)); + match stage { + crate::ShaderStage::Vertex | crate::ShaderStage::Fragment => { + // make dedicated input/output structs + writeln!(self.out, "struct {} {{", location_input_name)?; - match stage { - crate::ShaderStage::Compute { .. } => { - writeln!(self.out, "struct {} {{", location_input_name)?; - for ((handle, var), &usage) in - module.global_variables.iter().zip(&fun.global_usage) + for ((handle, var), &usage) in + module.global_variables.iter().zip(&fun.global_usage) + { + if var.class != crate::StorageClass::Input + || !usage.contains(crate::GlobalUse::LOAD) { - if var.class != crate::StorageClass::Input - || !usage.contains(crate::GlobalUse::LOAD) - { - continue; + continue; + } + // if it's a struct, lift all the built-in contents up to the root + let ty_handle = var.ty; + if let crate::TypeInner::Struct { ref members } = + module.types[ty_handle].inner + { + for (index, member) in members.iter().enumerate() { + if let crate::MemberOrigin::BuiltIn(built_in) = member.origin { + let name = member.name.or_index(MemberIndex(index)); + let ty_name = module.types[member.ty].name.or_index(member.ty); + write!(self.out, "\t{} {}", ty_name, name)?; + ResolvedBinding::BuiltIn(built_in) + .try_fmt_decorated(&mut self.out, ";\n")?; + } } - // if it's a struct, lift all the built-in contents up to the root - let ty_handle = var.ty; - if let crate::TypeInner::Struct { ref members } = - module.types[ty_handle].inner - { - for (index, member) in members.iter().enumerate() { - if let crate::MemberOrigin::BuiltIn(built_in) = member.origin { - let name = member.name.or_index(MemberIndex(index)); - let ty_name = - module.types[member.ty].name.or_index(member.ty); + } else if let Some(ref binding @ crate::Binding::Location(_)) = var.binding + { + let tyvar = TypedGlobalVariable { + module, + handle, + usage: crate::GlobalUse::empty(), + }; + let resolved = options.resolve_binding(binding, in_mode)?; + + write!(self.out, "\t")?; + tyvar.try_fmt(&mut self.out)?; + resolved.try_fmt_decorated(&mut self.out, ";\n")?; + } + } + writeln!(self.out, "}};")?; + + writeln!(self.out, "struct {} {{", output_name)?; + for ((handle, var), &usage) in + module.global_variables.iter().zip(&fun.global_usage) + { + if var.class != crate::StorageClass::Output + || !usage.contains(crate::GlobalUse::STORE) + { + continue; + } + // if it's a struct, lift all the built-in contents up to the root + let ty_handle = var.ty; + if let crate::TypeInner::Struct { ref members } = + module.types[ty_handle].inner + { + for (index, member) in members.iter().enumerate() { + let name = member.name.or_index(MemberIndex(index)); + let ty_name = module.types[member.ty].name.or_index(member.ty); + match member.origin { + crate::MemberOrigin::Empty => {} + crate::MemberOrigin::BuiltIn(built_in) => { write!(self.out, "\t{} {}", ty_name, name)?; ResolvedBinding::BuiltIn(built_in) .try_fmt_decorated(&mut self.out, ";\n")?; } - } - } else if let Some(ref binding @ crate::Binding::Location(_)) = - var.binding - { - let tyvar = TypedGlobalVariable { - module, - handle, - usage: crate::GlobalUse::empty(), - }; - let resolved = options.resolve_binding(binding, in_mode)?; - - write!(self.out, "\t")?; - tyvar.try_fmt(&mut self.out)?; - resolved.try_fmt_decorated(&mut self.out, ";\n")?; - } - } - writeln!(self.out, "}};")?; - writeln!(self.out, "struct {} {{", output_name)?; - for ((handle, var), &usage) in - module.global_variables.iter().zip(&fun.global_usage) - { - if var.class != crate::StorageClass::Output - || !usage.contains(crate::GlobalUse::STORE) - { - continue; - } - // if it's a struct, lift all the built-in contents up to the root - let ty_handle = var.ty; - if let crate::TypeInner::Struct { ref members } = - module.types[ty_handle].inner - { - for (index, member) in members.iter().enumerate() { - let name = member.name.or_index(MemberIndex(index)); - let ty_name = module.types[member.ty].name.or_index(member.ty); - match member.origin { - crate::MemberOrigin::Empty => {} - crate::MemberOrigin::BuiltIn(built_in) => { - write!(self.out, "\t{} {}", ty_name, name)?; - ResolvedBinding::BuiltIn(built_in) - .try_fmt_decorated(&mut self.out, ";\n")?; - } - crate::MemberOrigin::Offset(_) => { - //TODO - } + crate::MemberOrigin::Offset(_) => { + //TODO } } - } else { - let tyvar = TypedGlobalVariable { - module, - handle, - usage: crate::GlobalUse::empty(), - }; - write!(self.out, "\t")?; - tyvar.try_fmt(&mut self.out)?; - if let Some(ref binding) = var.binding { - let resolved = options.resolve_binding(binding, out_mode)?; - resolved.try_fmt_decorated(&mut self.out, "")?; - } - writeln!(self.out, ";")?; } + } else { + let tyvar = TypedGlobalVariable { + module, + handle, + usage: crate::GlobalUse::empty(), + }; + write!(self.out, "\t")?; + tyvar.try_fmt(&mut self.out)?; + if let Some(ref binding) = var.binding { + let resolved = options.resolve_binding(binding, out_mode)?; + resolved.try_fmt_decorated(&mut self.out, "")?; + } + writeln!(self.out, ";")?; } - writeln!(self.out, "}};")?; - writeln!(self.out, "{} {} {}(", em_str, output_name, fun_name)?; - let separator = separate(last_used_global.is_none()); - writeln!( - self.out, - "\t{} {} [[stage_in]]{}", - location_input_name, LOCATION_INPUT_STRUCT_NAME, separator - )?; } - _ => { - writeln!(self.out, "{} void {}(", em_str, fun_name)?; - } - }; + writeln!(self.out, "}};")?; - for ((handle, var), &usage) in module.global_variables.iter().zip(&fun.global_usage) - { - if usage.is_empty() || var.class == crate::StorageClass::Output { + writeln!(self.out, "{} {} {}(", em_str, output_name, name)?; + let separator = separate(last_used_global.is_none()); + writeln!( + self.out, + "\t{} {} [[stage_in]]{}", + location_input_name, LOCATION_INPUT_STRUCT_NAME, separator + )?; + } + crate::ShaderStage::Compute => { + writeln!(self.out, "{} void {}(", em_str, name)?; + } + }; + + for ((handle, var), &usage) in module.global_variables.iter().zip(&fun.global_usage) { + if usage.is_empty() || var.class == crate::StorageClass::Output { + continue; + } + if var.class == crate::StorageClass::Input { + if let Some(crate::Binding::Location(_)) = var.binding { + // location inputs are put into a separate struct continue; } - if var.class == crate::StorageClass::Input { - if let Some(crate::Binding::Location(_)) = var.binding { - // location inputs are put into a separate struct - continue; - } + } + let loc_mode = match (stage, var.class) { + (crate::ShaderStage::Vertex, crate::StorageClass::Input) => { + LocationMode::VertexInput } - let loc_mode = match (stage, var.class) { - (crate::ShaderStage::Vertex, crate::StorageClass::Input) => { - LocationMode::VertexInput - } - (crate::ShaderStage::Vertex, crate::StorageClass::Output) - | (crate::ShaderStage::Fragment { .. }, crate::StorageClass::Input) => { - LocationMode::Intermediate - } - (crate::ShaderStage::Fragment { .. }, crate::StorageClass::Output) => { - LocationMode::FragmentOutput - } - _ => LocationMode::Uniform, - }; - let resolved = - options.resolve_binding(var.binding.as_ref().unwrap(), loc_mode)?; - let tyvar = TypedGlobalVariable { - module, - handle, - usage, - }; - let separator = separate(last_used_global == Some(handle)); - write!(self.out, "\t")?; - tyvar.try_fmt(&mut self.out)?; - resolved.try_fmt_decorated(&mut self.out, separator)?; - writeln!(self.out)?; - } - } else { - let result_type_name = match fun.return_type { - Some(type_id) => module.types[type_id].name.or_index(type_id), - None => Name { - class: "", - source: NameSource::Custom { - name: "void", - prefix: false, - }, - }, + (crate::ShaderStage::Vertex, crate::StorageClass::Output) + | (crate::ShaderStage::Fragment { .. }, crate::StorageClass::Input) => { + LocationMode::Intermediate + } + (crate::ShaderStage::Fragment { .. }, crate::StorageClass::Output) => { + LocationMode::FragmentOutput + } + _ => LocationMode::Uniform, }; - writeln!(self.out, "{} {}(", result_type_name, fun_name)?; - for (index, &ty) in fun.parameter_types.iter().enumerate() { - let name = Name::from(ParameterIndex(index)); - let member_type_name = module.types[ty].name.or_index(ty); - let separator = separate( - index + 1 == fun.parameter_types.len() && last_used_global.is_none(), - ); - writeln!(self.out, "\t{} {}{}", member_type_name, name, separator)?; - } + let resolved = options.resolve_binding(var.binding.as_ref().unwrap(), loc_mode)?; + let tyvar = TypedGlobalVariable { + module, + handle, + usage, + }; + let separator = separate(last_used_global == Some(handle)); + write!(self.out, "\t")?; + tyvar.try_fmt(&mut self.out)?; + resolved.try_fmt_decorated(&mut self.out, separator)?; + writeln!(self.out)?; } writeln!(self.out, ") {{")?; - // write down function body - let has_output = match shader_stage { - Some(crate::ShaderStage::Vertex) | Some(crate::ShaderStage::Fragment { .. }) => { + let has_output = match stage { + crate::ShaderStage::Vertex | crate::ShaderStage::Fragment => { writeln!(self.out, "\t{} {};", output_name, OUTPUT_STRUCT_NAME)?; true } - Some(crate::ShaderStage::Compute { .. }) | None => false, + crate::ShaderStage::Compute => false, }; for (local_handle, local) in fun.local_variables.iter() { let ty_name = module.types[local.ty].name.or_index(local.ty); diff --git a/src/back/spv/writer.rs b/src/back/spv/writer.rs index 91cbc9f87c..b757aab8db 100644 --- a/src/back/spv/writer.rs +++ b/src/back/spv/writer.rs @@ -248,26 +248,121 @@ impl Writer { } } + fn write_function( + &mut self, + ir_function: &crate::Function, + ir_module: &crate::Module, + ) -> spirv::Word { + let mut function = Function::new(); + + for (_, variable) in ir_function.local_variables.iter() { + let id = self.generate_id(); + + let init_word = match variable.init { + Some(exp) => match &ir_function.expressions[exp] { + crate::Expression::Constant(handle) => { + Some(self.get_constant_id(*handle, ir_module)) + } + _ => unreachable!(), + }, + None => None, + }; + + let pointer_id = + self.get_pointer_id(&ir_module.types, variable.ty, spirv::StorageClass::Function); + function.variables.push(LocalVariable { + id, + name: variable.name.clone(), + instruction: super::instructions::instruction_variable( + pointer_id, + id, + spirv::StorageClass::Function, + init_word, + ), + }); + } + + let return_type_id = + self.get_function_return_type(ir_function.return_type, &ir_module.types); + let mut parameter_type_ids = Vec::with_capacity(ir_function.parameter_types.len()); + + let mut function_parameter_pointer_ids = vec![]; + + for parameter_type in ir_function.parameter_types.iter() { + let id = self.generate_id(); + let pointer_id = self.get_pointer_id( + &ir_module.types, + *parameter_type, + spirv::StorageClass::Function, + ); + + function_parameter_pointer_ids.push(pointer_id); + parameter_type_ids + .push(self.get_type_id(&ir_module.types, LookupType::Handle(*parameter_type))); + function + .parameters + .push(super::instructions::instruction_function_parameter( + pointer_id, id, + )); + } + + let lookup_function_type = LookupFunctionType { + return_type_id, + parameter_type_ids, + }; + + let id = self.generate_id(); + let function_type = + self.get_function_type(lookup_function_type, function_parameter_pointer_ids); + function.signature = Some(super::instructions::instruction_function( + return_type_id, + id, + spirv::FunctionControl::empty(), + function_type, + )); + + let mut block = Block::new(); + let id = self.generate_id(); + block.label = Some(super::instructions::instruction_label(id)); + for statement in ir_function.body.iter() { + self.write_function_statement( + ir_module, + ir_function, + &statement, + &mut block, + &mut function, + ); + } + function.blocks.push(block); + + function.to_words(&mut self.logical_layout.function_definitions); + super::instructions::instruction_function_end() + .to_words(&mut self.logical_layout.function_definitions); + + id + } + // TODO Move to instructions module fn write_entry_point( &mut self, entry_point: &crate::EntryPoint, + stage: crate::ShaderStage, + name: &str, ir_module: &crate::Module, ) -> Instruction { - let function_id = *self.lookup_function.get(&entry_point.function).unwrap(); + let function_id = self.write_function(&entry_point.function, ir_module); - let exec_model = match entry_point.stage { + let exec_model = match stage { crate::ShaderStage::Vertex => spirv::ExecutionModel::Vertex, crate::ShaderStage::Fragment { .. } => spirv::ExecutionModel::Fragment, crate::ShaderStage::Compute { .. } => spirv::ExecutionModel::GLCompute, }; let mut interface_ids = vec![]; - let function = &ir_module.functions[entry_point.function]; for ((handle, _), &usage) in ir_module .global_variables .iter() - .zip(&function.global_usage) + .zip(&entry_point.function.global_usage) { if usage.contains(crate::GlobalUse::STORE) || usage.contains(crate::GlobalUse::LOAD) { let id = self.get_global_variable_id( @@ -280,28 +375,26 @@ impl Writer { } self.try_add_capabilities(exec_model.required_capabilities()); - match entry_point.stage { + match stage { crate::ShaderStage::Vertex => {} - crate::ShaderStage::Fragment { .. } => { + crate::ShaderStage::Fragment => { let execution_mode = spirv::ExecutionMode::OriginUpperLeft; self.try_add_capabilities(execution_mode.required_capabilities()); super::instructions::instruction_execution_mode(function_id, execution_mode) .to_words(&mut self.logical_layout.execution_modes); } - crate::ShaderStage::Compute { .. } => {} + crate::ShaderStage::Compute => {} } if self.writer_flags.contains(WriterFlags::DEBUG) { - self.debugs.push(super::instructions::instruction_name( - function_id, - entry_point.name.as_str(), - )); + self.debugs + .push(super::instructions::instruction_name(function_id, name)); } super::instructions::instruction_entry_point( exec_model, function_id, - entry_point.name.as_str(), + name, interface_ids.as_slice(), ) } @@ -604,27 +697,27 @@ impl Writer { } } - match global_variable.binding.as_ref().unwrap() { + match *global_variable.binding.as_ref().unwrap() { crate::Binding::Location(location) => { self.annotations .push(super::instructions::instruction_decorate( id, spirv::Decoration::Location, - &[*location], + &[location], )); } - crate::Binding::Descriptor { set, binding } => { + crate::Binding::Resource { group, binding } => { self.annotations .push(super::instructions::instruction_decorate( id, spirv::Decoration::DescriptorSet, - &[*set], + &[group], )); self.annotations .push(super::instructions::instruction_decorate( id, spirv::Decoration::Binding, - &[*binding], + &[binding], )); } crate::Binding::BuiltIn(built_in) => { @@ -1000,100 +1093,12 @@ impl Writer { } for (handle, ir_function) in ir_module.functions.iter() { - let mut function = Function::new(); - - for (_, variable) in ir_function.local_variables.iter() { - let id = self.generate_id(); - - let init_word = match variable.init { - Some(exp) => match &ir_function.expressions[exp] { - crate::Expression::Constant(handle) => { - Some(self.get_constant_id(*handle, ir_module)) - } - _ => unreachable!(), - }, - None => None, - }; - - let pointer_id = self.get_pointer_id( - &ir_module.types, - variable.ty, - spirv::StorageClass::Function, - ); - function.variables.push(LocalVariable { - id, - name: variable.name.clone(), - instruction: super::instructions::instruction_variable( - pointer_id, - id, - spirv::StorageClass::Function, - init_word, - ), - }); - } - - let return_type_id = - self.get_function_return_type(ir_function.return_type, &ir_module.types); - let mut parameter_type_ids = Vec::with_capacity(ir_function.parameter_types.len()); - - let mut function_parameter_pointer_ids = vec![]; - - for parameter_type in ir_function.parameter_types.iter() { - let id = self.generate_id(); - let pointer_id = self.get_pointer_id( - &ir_module.types, - *parameter_type, - spirv::StorageClass::Function, - ); - - function_parameter_pointer_ids.push(pointer_id); - parameter_type_ids - .push(self.get_type_id(&ir_module.types, LookupType::Handle(*parameter_type))); - function - .parameters - .push(super::instructions::instruction_function_parameter( - pointer_id, id, - )); - } - - let lookup_function_type = LookupFunctionType { - return_type_id, - parameter_type_ids, - }; - - let id = self.generate_id(); - let function_type = - self.get_function_type(lookup_function_type, function_parameter_pointer_ids); - function.signature = Some(super::instructions::instruction_function( - return_type_id, - id, - spirv::FunctionControl::empty(), - function_type, - )); - + let id = self.write_function(ir_function, ir_module); self.lookup_function.insert(handle, id); - - let mut block = Block::new(); - let id = self.generate_id(); - block.label = Some(super::instructions::instruction_label(id)); - for statement in ir_function.body.iter() { - self.write_function_statement( - ir_module, - ir_function, - &statement, - &mut block, - &mut function, - ); - } - function.blocks.push(block); - - function.to_words(&mut self.logical_layout.function_definitions); - super::instructions::instruction_function_end() - .to_words(&mut self.logical_layout.function_definitions); } - for entry_point in ir_module.entry_points.iter() { - let entry_point_instruction = self.write_entry_point(entry_point, ir_module); + for (&(stage, ref name), ir_ep) in ir_module.entry_points.iter() { + let entry_point_instruction = self.write_entry_point(ir_ep, stage, name, ir_module); entry_point_instruction.to_words(&mut self.logical_layout.entry_points); } diff --git a/src/front/glsl/ast.rs b/src/front/glsl/ast.rs index 2919ef503d..2a32234e41 100644 --- a/src/front/glsl/ast.rs +++ b/src/front/glsl/ast.rs @@ -1,6 +1,6 @@ use crate::{ - Arena, BinaryOperator, Binding, Constant, Expression, FastHashMap, Function, GlobalVariable, - Handle, Interpolation, LocalVariable, ShaderStage, Statement, StorageClass, Type, + Arena, BinaryOperator, Binding, Expression, FastHashMap, Function, GlobalVariable, Handle, + Interpolation, LocalVariable, Module, ShaderStage, Statement, StorageClass, Type, }; #[derive(Debug)] @@ -8,28 +8,23 @@ pub struct Program { pub version: u16, pub profile: Profile, pub shader_stage: ShaderStage, + pub entry: Option, pub lookup_function: FastHashMap>, - pub functions: Arena, pub lookup_type: FastHashMap>, - pub types: Arena, - pub constants: Arena, - pub global_variables: Arena, pub lookup_global_variables: FastHashMap>, pub context: Context, + pub module: Module, } impl Program { - pub fn new(shader_stage: ShaderStage) -> Program { + pub fn new(shader_stage: ShaderStage, entry: String) -> Program { Program { version: 0, profile: Profile::Core, shader_stage, + entry: Some(entry), lookup_function: FastHashMap::default(), - functions: Arena::::new(), lookup_type: FastHashMap::default(), - types: Arena::::new(), - constants: Arena::::new(), - global_variables: Arena::::new(), lookup_global_variables: FastHashMap::default(), context: Context { expressions: Arena::::new(), @@ -37,6 +32,7 @@ impl Program { scopes: vec![FastHashMap::default()], lookup_global_var_exps: FastHashMap::default(), }, + module: Module::generate_empty(), } } diff --git a/src/front/glsl/mod.rs b/src/front/glsl/mod.rs index 338a21f680..ca0ddaad8c 100644 --- a/src/front/glsl/mod.rs +++ b/src/front/glsl/mod.rs @@ -1,4 +1,4 @@ -use crate::{EntryPoint, Module, ShaderStage}; +use crate::{Module, ShaderStage}; mod lex; #[cfg(test)] @@ -23,7 +23,7 @@ mod rosetta_tests; pub fn parse_str(source: &str, entry: String, stage: ShaderStage) -> Result { log::debug!("------ GLSL-pomelo ------"); - let mut program = Program::new(stage); + let mut program = Program::new(stage, entry); let lex = Lexer::new(source); let mut parser = parser::Parser::new(&mut program); @@ -32,20 +32,5 @@ pub fn parse_str(source: &str, entry: String, stage: ShaderStage) -> Result Result { - let mut program = Program::new(stage); + let mut program = Program::new(stage, "".to_string()); let lex = Lexer::new(source); let mut parser = parser::Parser::new(&mut program); diff --git a/src/front/glsl/variables.rs b/src/front/glsl/variables.rs index 263bca52e3..834f787651 100644 --- a/src/front/glsl/variables.rs +++ b/src/front/glsl/variables.rs @@ -18,25 +18,28 @@ impl Program { return Err(ErrorKind::VariableNotAvailable(name.into())); } }; - let h = self.global_variables.fetch_or_append(GlobalVariable { - name: Some(name.into()), - class: if self.shader_stage == ShaderStage::Vertex { - StorageClass::Output - } else { - StorageClass::Input - }, - binding: Some(Binding::BuiltIn(BuiltIn::Position)), - ty: self.types.fetch_or_append(Type { - name: None, - inner: TypeInner::Vector { - size: VectorSize::Quad, - kind: ScalarKind::Float, - width: 4, + let h = self + .module + .global_variables + .fetch_or_append(GlobalVariable { + name: Some(name.into()), + class: if self.shader_stage == ShaderStage::Vertex { + StorageClass::Output + } else { + StorageClass::Input }, - }), - interpolation: None, - storage_access: StorageAccess::empty(), - }); + binding: Some(Binding::BuiltIn(BuiltIn::Position)), + ty: self.module.types.fetch_or_append(Type { + name: None, + inner: TypeInner::Vector { + size: VectorSize::Quad, + kind: ScalarKind::Float, + width: 4, + }, + }), + interpolation: None, + storage_access: StorageAccess::empty(), + }); self.lookup_global_variables.insert(name.into(), h); let exp = self .context @@ -54,20 +57,23 @@ impl Program { return Err(ErrorKind::VariableNotAvailable(name.into())); } }; - let h = self.global_variables.fetch_or_append(GlobalVariable { - name: Some(name.into()), - class: StorageClass::Input, - binding: Some(Binding::BuiltIn(BuiltIn::VertexIndex)), - ty: self.types.fetch_or_append(Type { - name: None, - inner: TypeInner::Scalar { - kind: ScalarKind::Uint, - width: 4, - }, - }), - interpolation: None, - storage_access: StorageAccess::empty(), - }); + let h = self + .module + .global_variables + .fetch_or_append(GlobalVariable { + name: Some(name.into()), + class: StorageClass::Input, + binding: Some(Binding::BuiltIn(BuiltIn::VertexIndex)), + ty: self.module.types.fetch_or_append(Type { + name: None, + inner: TypeInner::Scalar { + kind: ScalarKind::Uint, + width: 4, + }, + }), + interpolation: None, + storage_access: StorageAccess::empty(), + }); self.lookup_global_variables.insert(name.into(), h); let exp = self .context diff --git a/src/front/mod.rs b/src/front/mod.rs index 3e74645776..cb543c72a9 100644 --- a/src/front/mod.rs +++ b/src/front/mod.rs @@ -19,7 +19,7 @@ impl crate::Module { constants: Arena::new(), global_variables: Arena::new(), functions: Arena::new(), - entry_points: Vec::new(), + entry_points: crate::FastHashMap::default(), } } diff --git a/src/front/spv/function.rs b/src/front/spv/function.rs index 06de75e41f..335e1a33e9 100644 --- a/src/front/spv/function.rs +++ b/src/front/spv/function.rs @@ -49,123 +49,142 @@ pub enum Terminator { Unreachable, } -pub fn parse_function>( - parser: &mut super::Parser, - inst: Instruction, - module: &mut crate::Module, -) -> Result<(), Error> { - parser.switch(ModuleState::Function, inst.op)?; - inst.expect(5)?; - let result_type = parser.next()?; - let fun_id = parser.next()?; - let _fun_control = parser.next()?; - let fun_type = parser.next()?; - let mut fun = { - let ft = parser.lookup_function_type.lookup(fun_type)?; - if ft.return_type_id != result_type { - return Err(Error::WrongFunctionResultType(result_type)); - } - crate::Function { - name: parser.future_decor.remove(&fun_id).and_then(|dec| dec.name), - parameter_types: Vec::with_capacity(ft.parameter_type_ids.len()), - return_type: if parser.lookup_void_type.contains(&result_type) { - None - } else { - Some(parser.lookup_type.lookup(result_type)?.handle) - }, - global_usage: Vec::new(), - local_variables: Arena::new(), - expressions: parser.make_expression_storage(), - body: Vec::new(), - } - }; +impl> super::Parser { + pub fn parse_function( + &mut self, + inst: Instruction, + module: &mut crate::Module, + ) -> Result<(), Error> { + self.switch(ModuleState::Function, inst.op)?; + inst.expect(5)?; + let result_type = self.next()?; + let fun_id = self.next()?; + let _fun_control = self.next()?; + let fun_type = self.next()?; - // read parameters - for i in 0..fun.parameter_types.capacity() { - match parser.next_inst()? { - Instruction { - op: spirv::Op::FunctionParameter, - wc: 3, - } => { - let type_id = parser.next()?; - let id = parser.next()?; - let handle = fun - .expressions - .append(crate::Expression::FunctionParameter(i as u32)); - parser - .lookup_expression - .insert(id, LookupExpression { type_id, handle }); - //Note: we redo the lookup in order to work around `parser` borrowing + let mut fun = { + let ft = self.lookup_function_type.lookup(fun_type)?; + if ft.return_type_id != result_type { + return Err(Error::WrongFunctionResultType(result_type)); + } + crate::Function { + name: self.future_decor.remove(&fun_id).and_then(|dec| dec.name), + parameter_types: Vec::with_capacity(ft.parameter_type_ids.len()), + return_type: if self.lookup_void_type.contains(&result_type) { + None + } else { + Some(self.lookup_type.lookup(result_type)?.handle) + }, + global_usage: Vec::new(), + local_variables: Arena::new(), + expressions: self.make_expression_storage(), + body: Vec::new(), + } + }; - if type_id - != parser - .lookup_function_type - .lookup(fun_type)? - .parameter_type_ids[i] - { - return Err(Error::WrongFunctionParameterType(type_id)); + // read parameters + for i in 0..fun.parameter_types.capacity() { + match self.next_inst()? { + Instruction { + op: spirv::Op::FunctionParameter, + wc: 3, + } => { + let type_id = self.next()?; + let id = self.next()?; + let handle = fun + .expressions + .append(crate::Expression::FunctionParameter(i as u32)); + self.lookup_expression + .insert(id, LookupExpression { type_id, handle }); + //Note: we redo the lookup in order to work around `self` borrowing + + if type_id + != self + .lookup_function_type + .lookup(fun_type)? + .parameter_type_ids[i] + { + return Err(Error::WrongFunctionParameterType(type_id)); + } + let ty = self.lookup_type.lookup(type_id)?.handle; + fun.parameter_types.push(ty); } - let ty = parser.lookup_type.lookup(type_id)?.handle; - fun.parameter_types.push(ty); - } - Instruction { op, .. } => return Err(Error::InvalidParameter(op)), - } - } - - // Read body - let mut local_function_calls = FastHashMap::default(); - let mut flow_graph = FlowGraph::new(); - - // Scan the blocks and add them as nodes - loop { - let fun_inst = parser.next_inst()?; - log::debug!("{:?}", fun_inst.op); - match fun_inst.op { - spirv::Op::Label => { - // Read the label ID - fun_inst.expect(2)?; - let block_id = parser.next()?; - - let node = parser.next_block( - block_id, - &mut fun.expressions, - &mut fun.local_variables, - &module.types, - &module.constants, - &module.global_variables, - &mut local_function_calls, - )?; - - flow_graph.add_node(node); - } - spirv::Op::FunctionEnd => { - fun_inst.expect(1)?; - break; - } - _ => { - return Err(Error::UnsupportedInstruction(parser.state, fun_inst.op)); + Instruction { op, .. } => return Err(Error::InvalidParameter(op)), } } + + // Read body + let mut local_function_calls = FastHashMap::default(); + let mut flow_graph = FlowGraph::new(); + + // Scan the blocks and add them as nodes + loop { + let fun_inst = self.next_inst()?; + log::debug!("{:?}", fun_inst.op); + match fun_inst.op { + spirv::Op::Label => { + // Read the label ID + fun_inst.expect(2)?; + let block_id = self.next()?; + + let node = self.next_block( + block_id, + &mut fun.expressions, + &mut fun.local_variables, + &module.types, + &module.constants, + &module.global_variables, + &mut local_function_calls, + )?; + + flow_graph.add_node(node); + } + spirv::Op::FunctionEnd => { + fun_inst.expect(1)?; + break; + } + _ => { + return Err(Error::UnsupportedInstruction(self.state, fun_inst.op)); + } + } + } + + flow_graph.classify(); + flow_graph.remove_phi_instructions(&self.lookup_expression); + fun.body = flow_graph.to_naga()?; + + // done + fun.fill_global_use(&module.global_variables); + + let source = match self.lookup_entry_point.remove(&fun_id) { + Some(ep) => { + module.entry_points.insert( + (ep.stage, ep.name.clone()), + crate::EntryPoint { + early_depth_test: ep.early_depth_test, + workgroup_size: ep.workgroup_size, + function: fun, + }, + ); + DeferredSource::EntryPoint(ep.stage, ep.name) + } + None => { + let handle = module.functions.append(fun); + self.lookup_function.insert(fun_id, handle); + DeferredSource::Function(handle) + } + }; + + for (expr_handle, dst_id) in local_function_calls { + self.deferred_function_calls.push(DeferredFunctionCall { + source: source.clone(), + expr_handle, + dst_id, + }); + } + + self.lookup_expression.clear(); + self.lookup_sampled_image.clear(); + Ok(()) } - - flow_graph.classify(); - flow_graph.remove_phi_instructions(&parser.lookup_expression); - fun.body = flow_graph.to_naga()?; - - // done - fun.global_usage = - crate::GlobalUse::scan(&fun.expressions, &fun.body, &module.global_variables); - let handle = module.functions.append(fun); - for (expr_handle, dst_id) in local_function_calls { - parser.deferred_function_calls.push(DeferredFunctionCall { - source_handle: handle, - expr_handle, - dst_id, - }); - } - - parser.lookup_function.insert(fun_id, handle); - parser.lookup_expression.clear(); - parser.lookup_sampled_image.clear(); - Ok(()) } diff --git a/src/front/spv/mod.rs b/src/front/spv/mod.rs index a9115f9c57..c4c75f69e8 100644 --- a/src/front/spv/mod.rs +++ b/src/front/spv/mod.rs @@ -197,10 +197,10 @@ impl Decoration { Decoration { built_in: None, location: None, - desc_set: Some(set), + desc_set: Some(group), desc_index: Some(binding), .. - } => Some(crate::Binding::Descriptor { set, binding }), + } => Some(crate::Binding::Resource { group, binding }), _ => None, } } @@ -252,6 +252,8 @@ struct LookupFunctionType { struct EntryPoint { stage: crate::ShaderStage, name: String, + early_depth_test: Option, + workgroup_size: [u32; 3], function_id: spirv::Word, variable_ids: Vec, } @@ -285,8 +287,13 @@ struct LookupSampledImage { image: Handle, sampler: Handle, } +#[derive(Clone, Debug)] +enum DeferredSource { + EntryPoint(crate::ShaderStage, String), + Function(Handle), +} struct DeferredFunctionCall { - source_handle: Handle, + source: DeferredSource, expr_handle: Handle, dst_id: spirv::Word, } @@ -1444,7 +1451,7 @@ impl> Parser { Op::Constant | Op::SpecConstant => self.parse_constant(inst, &mut module), Op::ConstantComposite => self.parse_composite_constant(inst, &mut module), Op::Variable => self.parse_global_variable(inst, &mut module), - Op::Function => parse_function(&mut self, inst, &mut module), + Op::Function => self.parse_function(inst, &mut module), _ => Err(Error::UnsupportedInstruction(self.state, inst.op)), //TODO }?; } @@ -1476,12 +1483,17 @@ impl> Parser { for dfc in self.deferred_function_calls.drain(..) { let dst_handle = *self.lookup_function.lookup(dfc.dst_id)?; - match *module - .functions - .get_mut(dfc.source_handle) - .expressions - .get_mut(dfc.expr_handle) - { + let fun = match dfc.source { + DeferredSource::Function(fun_handle) => module.functions.get_mut(fun_handle), + DeferredSource::EntryPoint(stage, name) => { + &mut module + .entry_points + .get_mut(&(stage, name)) + .unwrap() + .function + } + }; + match *fun.expressions.get_mut(dfc.expr_handle) { crate::Expression::Call { ref mut origin, arguments: _, @@ -1499,15 +1511,6 @@ impl> Parser { self.future_member_decor.clear(); } - module.entry_points.reserve(self.lookup_entry_point.len()); - for raw in self.lookup_entry_point.values() { - module.entry_points.push(crate::EntryPoint { - stage: raw.stage, - name: raw.name.clone(), - function: *self.lookup_function.lookup(raw.function_id)?, - }); - } - Ok(module) } @@ -1570,15 +1573,13 @@ impl> Parser { let ep = EntryPoint { stage: match exec_model { spirv::ExecutionModel::Vertex => crate::ShaderStage::Vertex, - spirv::ExecutionModel::Fragment => crate::ShaderStage::Fragment { - early_depth_test: None, - }, - spirv::ExecutionModel::GLCompute => crate::ShaderStage::Compute { - local_size: (0, 0, 0), - }, + spirv::ExecutionModel::Fragment => crate::ShaderStage::Fragment, + spirv::ExecutionModel::GLCompute => crate::ShaderStage::Compute, _ => return Err(Error::UnsupportedExecutionModel(exec_model as u32)), }, name, + early_depth_test: None, + workgroup_size: [0; 3], function_id, variable_ids: self.data.by_ref().take(left as usize).collect(), }; @@ -1587,7 +1588,6 @@ impl> Parser { } fn parse_execution_mode(&mut self, inst: Instruction) -> Result<(), Error> { - use crate::ShaderStage; use spirv::ExecutionMode; self.switch(ModuleState::ExecutionMode, inst.op)?; @@ -1597,63 +1597,47 @@ impl> Parser { let mode_id = self.next()?; let args: Vec = self.data.by_ref().take(inst.wc as usize - 3).collect(); - let ep: &mut EntryPoint = self + let ep = self .lookup_entry_point .get_mut(&ep_id) .ok_or(Error::InvalidId(ep_id))?; let mode = spirv::ExecutionMode::from_u32(mode_id) .ok_or(Error::UnsupportedExecutionMode(mode_id))?; - match ep.stage { - ShaderStage::Fragment { - ref mut early_depth_test, - } => { - match mode { - ExecutionMode::EarlyFragmentTests => { - if early_depth_test.is_none() { - *early_depth_test = Some(crate::EarlyDepthTest { conservative: None }); - } - } - ExecutionMode::DepthUnchanged => { - *early_depth_test = Some(crate::EarlyDepthTest { - conservative: Some(crate::ConservativeDepth::Unchanged), - }); - } - ExecutionMode::DepthGreater => { - *early_depth_test = Some(crate::EarlyDepthTest { - conservative: Some(crate::ConservativeDepth::GreaterEqual), - }); - } - ExecutionMode::DepthLess => { - *early_depth_test = Some(crate::EarlyDepthTest { - conservative: Some(crate::ConservativeDepth::LessEqual), - }); - } - ExecutionMode::DepthReplacing => { - // Ignored because it can be deduced from the IR. - } - ExecutionMode::OriginUpperLeft => { - // Ignored because the other option (OriginLowerLeft) is not valid in Vulkan mode. - } - _ => { - return Err(Error::UnsupportedExecutionMode(mode_id)); - } - }; + match mode { + ExecutionMode::EarlyFragmentTests => { + if ep.early_depth_test.is_none() { + ep.early_depth_test = Some(crate::EarlyDepthTest { conservative: None }); + } } - ShaderStage::Compute { ref mut local_size } => { - match mode { - ExecutionMode::LocalSize => { - *local_size = (args[0], args[1], args[2]); - } - _ => { - return Err(Error::UnsupportedExecutionMode(mode_id)); - } - }; + ExecutionMode::DepthUnchanged => { + ep.early_depth_test = Some(crate::EarlyDepthTest { + conservative: Some(crate::ConservativeDepth::Unchanged), + }); + } + ExecutionMode::DepthGreater => { + ep.early_depth_test = Some(crate::EarlyDepthTest { + conservative: Some(crate::ConservativeDepth::GreaterEqual), + }); + } + ExecutionMode::DepthLess => { + ep.early_depth_test = Some(crate::EarlyDepthTest { + conservative: Some(crate::ConservativeDepth::LessEqual), + }); + } + ExecutionMode::DepthReplacing => { + // Ignored because it can be deduced from the IR. + } + ExecutionMode::OriginUpperLeft => { + // Ignored because the other option (OriginLowerLeft) is not valid in Vulkan mode. + } + ExecutionMode::LocalSize => { + ep.workgroup_size = [args[0], args[1], args[2]]; } _ => { return Err(Error::UnsupportedExecutionMode(mode_id)); } - }; + } Ok(()) } diff --git a/src/front/wgsl/mod.rs b/src/front/wgsl/mod.rs index f3f980e9b6..ed0feb8311 100644 --- a/src/front/wgsl/mod.rs +++ b/src/front/wgsl/mod.rs @@ -320,12 +320,8 @@ impl Parser { fn get_shader_stage(word: &str) -> Result> { match word { "vertex" => Ok(crate::ShaderStage::Vertex), - "fragment" => Ok(crate::ShaderStage::Fragment { - early_depth_test: None, - }), - "compute" => Ok(crate::ShaderStage::Compute { - local_size: (0, 0, 0), - }), + "fragment" => Ok(crate::ShaderStage::Fragment), + "compute" => Ok(crate::ShaderStage::Compute), _ => Err(Error::UnknownShaderStage(word)), } } @@ -563,7 +559,6 @@ impl Parser { Token::Separator('.') => { let _ = lexer.next(); let name = lexer.next_ident()?; - println!("Resolving '{}' for {:?}", name, ctx.resolve_type(handle)); //TEMP let expression = match *ctx.resolve_type(handle)? { crate::TypeInner::Struct { ref members } => { let index = members @@ -918,7 +913,9 @@ impl Parser { ready = true; } Token::Word("offset") if ready => { + lexer.expect(Token::Paren('('))?; offset = lexer.next_uint_literal()?; + lexer.expect(Token::Paren(')'))?; ready = false; } other => return Err(Error::Unexpected(other)), @@ -1288,7 +1285,7 @@ impl Parser { lexer: &mut Lexer<'a>, module: &mut crate::Module, lookup_global_expression: &FastHashMap<&'a str, crate::Expression>, - ) -> Result, Error<'a>> { + ) -> Result<(crate::Function, &'a str), Error<'a>> { self.scopes.push(Scope::FunctionDecl); // read function name let mut lookup_ident = FastHashMap::default(); @@ -1322,7 +1319,7 @@ impl Parser { Some(self.parse_type_decl(lexer, None, &mut module.types)?.0) }; - let fun_handle = module.functions.append(crate::Function { + let mut fun = crate::Function { name: Some(fun_name.to_string()), parameter_types, return_type, @@ -1330,15 +1327,7 @@ impl Parser { local_variables: Arena::new(), expressions, body: Vec::new(), - }); - if self - .function_lookup - .insert(fun_name.to_string(), fun_handle) - .is_some() - { - return Err(Error::FunctionRedefinition(fun_name)); - } - let fun = module.functions.get_mut(fun_handle); + }; // read body let mut typifier = Typifier::new(); @@ -1356,11 +1345,10 @@ impl Parser { }, )?; // done - fun.global_usage = - crate::GlobalUse::scan(&fun.expressions, &fun.body, &module.global_variables); + fun.fill_global_use(&module.global_variables); self.scopes.pop(); - Ok(fun_handle) + Ok((fun, fun_name)) } fn parse_global_decl<'a>( @@ -1373,30 +1361,56 @@ impl Parser { let mut binding = None; // Perspective is the default qualifier. let mut interpolation = None; + let mut stage = None; + let mut workgroup_size = [1u32; 3]; + if lexer.skip(Token::DoubleParen('[')) { - let (mut bind_index, mut bind_set) = (None, None); + let (mut bind_index, mut bind_group) = (None, None); self.scopes.push(Scope::Decoration); loop { match lexer.next_ident()? { "location" => { + lexer.expect(Token::Paren('('))?; let loc = lexer.next_uint_literal()?; + lexer.expect(Token::Paren(')'))?; binding = Some(crate::Binding::Location(loc)); } "builtin" => { + lexer.expect(Token::Paren('('))?; let builtin = Self::get_built_in(lexer.next_ident()?)?; + lexer.expect(Token::Paren(')'))?; binding = Some(crate::Binding::BuiltIn(builtin)); } "binding" => { + lexer.expect(Token::Paren('('))?; bind_index = Some(lexer.next_uint_literal()?); + lexer.expect(Token::Paren(')'))?; } - "set" => { - bind_set = Some(lexer.next_uint_literal()?); + "group" => { + lexer.expect(Token::Paren('('))?; + bind_group = Some(lexer.next_uint_literal()?); + lexer.expect(Token::Paren(')'))?; } "interpolate" => { - if interpolation.is_some() { - return Err(Error::UnknownDecoration(lexer.next_ident()?)); - } + lexer.expect(Token::Paren('('))?; interpolation = Some(Self::get_interpolation(lexer.next_ident()?)?); + lexer.expect(Token::Paren(')'))?; + } + "stage" => { + lexer.expect(Token::Paren('('))?; + stage = Some(Self::get_shader_stage(lexer.next_ident()?)?); + lexer.expect(Token::Paren(')'))?; + } + "workgroup_size" => { + lexer.expect(Token::Paren('('))?; + for (i, size) in workgroup_size.iter_mut().enumerate() { + *size = lexer.next_uint_literal()?; + match lexer.next() { + Token::Paren(')') => break, + Token::Separator(',') if i != 2 => (), + other => return Err(Error::Unexpected(other)), + } + } } word => return Err(Error::UnknownDecoration(word)), } @@ -1408,15 +1422,11 @@ impl Parser { other => return Err(Error::Unexpected(other)), } } - match (bind_set, bind_index) { - (Some(set), Some(index)) if binding.is_none() => { - binding = Some(crate::Binding::Descriptor { - set, - binding: index, - }); - } - _ if binding.is_none() => return Err(Error::Other), - _ => {} + if let (Some(group), Some(index)) = (bind_group, bind_index) { + binding = Some(crate::Binding::Resource { + group, + binding: index, + }); } self.scopes.pop(); } @@ -1494,31 +1504,30 @@ impl Parser { .insert(pvar.name, crate::Expression::GlobalVariable(var_handle)); } Token::Word("fn") => { - self.parse_function_decl(lexer, module, &lookup_global_expression)?; - } - Token::Word("entry_point") => { - let stage = Self::get_shader_stage(lexer.next_ident()?)?; - let export_name = if lexer.skip(Token::Word("as")) { - match lexer.next() { - Token::String(name) => Some(name), - other => return Err(Error::Unexpected(other)), + let (function, name) = + self.parse_function_decl(lexer, module, &lookup_global_expression)?; + let already_declared = match stage { + Some(stage) => module + .entry_points + .insert( + (stage, name.to_string()), + crate::EntryPoint { + early_depth_test: None, + workgroup_size, + function, + }, + ) + .is_some(), + None => { + let fun_handle = module.functions.append(function); + self.function_lookup + .insert(name.to_string(), fun_handle) + .is_some() } - } else { - None }; - lexer.expect(Token::Operation('='))?; - let fun_ident = lexer.next_ident()?; - lexer.expect(Token::Separator(';'))?; - let (fun_handle, _) = module - .functions - .iter() - .find(|(_, fun)| fun.name.as_deref() == Some(fun_ident)) - .ok_or(Error::UnknownFunction(fun_ident))?; - module.entry_points.push(crate::EntryPoint { - stage, - name: export_name.unwrap_or(fun_ident).to_owned(), - function: fun_handle, - }); + if already_declared { + return Err(Error::FunctionRedefinition(name)); + } } Token::End => return Ok(false), token => return Err(Error::Unexpected(token)), diff --git a/src/lib.rs b/src/lib.rs index 4cded46cac..7086834c2f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -88,18 +88,14 @@ pub enum ConservativeDepth { } /// Stage of the programmable pipeline. -#[derive(Clone, Copy, Debug, PartialEq)] +#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[allow(missing_docs)] // The names are self evident pub enum ShaderStage { Vertex, - Fragment { - early_depth_test: Option, - }, - Compute { - local_size: (u32, u32, u32), - }, + Fragment, + Compute, } /// Class of storage for variables. @@ -416,8 +412,8 @@ pub enum Binding { BuiltIn(BuiltIn), /// Indexed location. Location(u32), - /// Binding within a descriptor set. - Descriptor { set: u32, binding: u32 }, + /// Binding within a resource group. + Resource { group: u32, binding: u32 }, } bitflags::bitflags! { @@ -635,7 +631,7 @@ pub enum Expression { origin: FunctionOrigin, arguments: Vec>, }, - /// Get dynamic array length. + /// Get the length of an array. ArrayLength(Handle), } @@ -718,12 +714,12 @@ pub struct Function { #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] pub struct EntryPoint { - /// The stage in the programmable pipeline this entry point is for. - pub stage: ShaderStage, - /// Name identifying this entry point. - pub name: String, - /// The function to be used. - pub function: Handle, + /// Early depth test for fragment stages. + pub early_depth_test: Option, + /// Workgroup size for compute stages + pub workgroup_size: [u32; 3], + /// The entrance function. + pub function: Function, } /// Shader module. @@ -751,6 +747,6 @@ pub struct Module { pub global_variables: Arena, /// Storage for the functions defined in this module. pub functions: Arena, - /// Vector of exported entry points. - pub entry_points: Vec, + /// Exported entry points. + pub entry_points: FastHashMap<(ShaderStage, String), EntryPoint>, } diff --git a/src/proc/interface.rs b/src/proc/interface.rs index 8e1f5666ce..102da6e6a3 100644 --- a/src/proc/interface.rs +++ b/src/proc/interface.rs @@ -2,12 +2,13 @@ use crate::arena::{Arena, Handle}; pub struct Interface<'a, T> { pub expressions: &'a Arena, - pub visitor: &'a mut T, + pub visitor: T, } pub trait Visitor { fn visit_expr(&mut self, _: &crate::Expression) {} fn visit_lhs_expr(&mut self, _: &crate::Expression) {} + fn visit_fun(&mut self, _: Handle) {} } impl<'a, T> Interface<'a, T> @@ -95,10 +96,16 @@ where E::Derivative { expr, .. } => { self.traverse_expr(expr); } - E::Call { ref arguments, .. } => { + E::Call { + ref origin, + ref arguments, + } => { for &argument in arguments { self.traverse_expr(argument); } + if let crate::FunctionOrigin::Local(fun) = *origin { + self.visitor.visit_fun(fun); + } } E::ArrayLength(expr) => { self.traverse_expr(expr); @@ -168,9 +175,9 @@ where } } -struct GlobalUseVisitor(Vec); +struct GlobalUseVisitor<'a>(&'a mut [crate::GlobalUse]); -impl Visitor for GlobalUseVisitor { +impl Visitor for GlobalUseVisitor<'_> { fn visit_expr(&mut self, expr: &crate::Expression) { if let crate::Expression::GlobalVariable(handle) = expr { self.0[handle.index()] |= crate::GlobalUse::LOAD; @@ -184,20 +191,17 @@ impl Visitor for GlobalUseVisitor { } } -impl crate::GlobalUse { - pub fn scan( - expressions: &Arena, - body: &[crate::Statement], - globals: &Arena, - ) -> Vec { - let mut visitor = GlobalUseVisitor(vec![crate::GlobalUse::empty(); globals.len()]); +impl crate::Function { + pub fn fill_global_use(&mut self, globals: &Arena) { + self.global_usage.clear(); + self.global_usage + .resize(globals.len(), crate::GlobalUse::empty()); let mut io = Interface { - expressions, - visitor: &mut visitor, + expressions: &self.expressions, + visitor: GlobalUseVisitor(&mut self.global_usage), }; - io.traverse(body); - visitor.0 + io.traverse(&self.body); } } @@ -248,14 +252,25 @@ mod tests { }, ]; + let mut function = crate::Function { + name: None, + parameter_types: Vec::new(), + return_type: None, + local_variables: Arena::new(), + expressions, + global_usage: Vec::new(), + body: test_body, + }; + function.fill_global_use(&test_globals); + assert_eq!( - GlobalUse::scan(&expressions, &test_body, &test_globals), - vec![ + &function.global_usage, + &[ GlobalUse::LOAD, GlobalUse::STORE, GlobalUse::STORE, GlobalUse::LOAD, - ] + ], ) } } diff --git a/test-data/boids.param.ron b/test-data/boids.param.ron index db17ca3410..bfd7487778 100644 --- a/test-data/boids.param.ron +++ b/test-data/boids.param.ron @@ -1,7 +1,7 @@ ( metal_bindings: { - (set: 0, binding: 0): (buffer: Some(0), texture: None, sampler: None, mutable: false), - (set: 0, binding: 1): (buffer: Some(1), texture: None, sampler: None, mutable: true), - (set: 0, binding: 2): (buffer: Some(2), texture: None, sampler: None, mutable: true), + (group: 0, binding: 0): (buffer: Some(0), texture: None, sampler: None, mutable: false), + (group: 0, binding: 1): (buffer: Some(1), texture: None, sampler: None, mutable: true), + (group: 0, binding: 2): (buffer: Some(2), texture: None, sampler: None, mutable: true), } ) diff --git a/test-data/boids.wgsl b/test-data/boids.wgsl index b3aeca6e07..efb4a69684 100644 --- a/test-data/boids.wgsl +++ b/test-data/boids.wgsl @@ -16,12 +16,13 @@ import "GLSL.std.450" as std; # vertex shader -[[location 0]] var a_particlePos : vec2; -[[location 1]] var a_particleVel : vec2; -[[location 2]] var a_pos : vec2; -[[builtin position]] var gl_Position : vec4; +[[location(0)]] var a_particlePos : vec2; +[[location(1)]] var a_particleVel : vec2; +[[location(2)]] var a_pos : vec2; +[[builtin(position)]] var gl_Position : vec4; -fn vtx_main() -> void { +[[stage(vertex)]] +fn main() -> void { var angle : f32 = -std::atan2(a_particleVel.x, a_particleVel.y); var pos : vec2 = vec2( (a_pos.x * std::cos(angle)) - (a_pos.y * std::sin(angle)), @@ -29,45 +30,45 @@ fn vtx_main() -> void { gl_Position = vec4(pos + a_particlePos, 0.0, 1.0); return; } -entry_point vertex as "main" = vtx_main; # fragment shader -[[location 0]] var fragColor : vec4; +[[location(0)]] var fragColor : vec4; -fn frag_main() -> void { +[[stage(fragment)]] +fn main() -> void { fragColor = vec4(1.0, 1.0, 1.0, 1.0); return; } -entry_point fragment as "main" = frag_main; # compute shader type Particle = struct { - [[offset 0]] pos : vec2; - [[offset 8]] vel : vec2; + [[offset(0)]] pos : vec2; + [[offset(8)]] vel : vec2; }; type SimParams = struct { - [[offset 0]] deltaT : f32; - [[offset 4]] rule1Distance : f32; - [[offset 8]] rule2Distance : f32; - [[offset 12]] rule3Distance : f32; - [[offset 16]] rule1Scale : f32; - [[offset 20]] rule2Scale : f32; - [[offset 24]] rule3Scale : f32; + [[offset(0)]] deltaT : f32; + [[offset(4)]] rule1Distance : f32; + [[offset(8)]] rule2Distance : f32; + [[offset(12)]] rule3Distance : f32; + [[offset(16)]] rule1Scale : f32; + [[offset(20)]] rule2Scale : f32; + [[offset(24)]] rule3Scale : f32; }; type Particles = struct { - [[offset 0]] particles : [[stride 16]] array; + [[offset(0)]] particles : [[stride 16]] array; }; -[[binding 0, set 0]] var params : SimParams; -[[binding 1, set 0]] var particlesA : Particles; -[[binding 2, set 0]] var particlesB : Particles; +[[group(0), binding(0)]] var params : SimParams; +[[group(0), binding(1)]] var particlesA : Particles; +[[group(0), binding(2)]] var particlesB : Particles; -[[builtin global_invocation_id]] var gl_GlobalInvocationID : vec3; +[[builtin(global_invocation_id)]] var gl_GlobalInvocationID : vec3; # https://github.com/austinEng/Project6-Vulkan-Flocking/blob/master/data/shaders/computeparticles/particle.comp -fn compute_main() -> void { +[[stage(compute)]] +fn main() -> void { var index : u32 = gl_GlobalInvocationID.x; if (index >= 5) { return; @@ -148,5 +149,3 @@ fn compute_main() -> void { return; } -entry_point compute as "main" = compute_main; - diff --git a/test-data/function.wgsl b/test-data/function.wgsl index 7252e1cd1b..5a968cadd4 100644 --- a/test-data/function.wgsl +++ b/test-data/function.wgsl @@ -4,9 +4,8 @@ fn test_function(test: f32) -> f32 { return test; } -fn main_vert() -> void { +[[stage(vertex)]] +fn main() -> void { var foo: f32 = std::glsl::distance(0.0, 1.0); var test: f32 = test_function(1.0); } - -entry_point vertex as "main" = main_vert; \ No newline at end of file diff --git a/test-data/quad.wgsl b/test-data/quad.wgsl index 0e2dc08da7..1e83ccd724 100644 --- a/test-data/quad.wgsl +++ b/test-data/quad.wgsl @@ -1,24 +1,24 @@ # vertex const c_scale: f32 = 1.2; -[[location 0]] var a_pos : vec2; -[[location 1]] var a_uv : vec2; -[[location 0]] var v_uv : vec2; -[[builtin position]] var o_position : vec4; +[[location(0)]] var a_pos : vec2; +[[location(1)]] var a_uv : vec2; +[[location(0)]] var v_uv : vec2; +[[builtin(position)]] var o_position : vec4; -fn main_vert() -> void { +[[stage(vertex)]] +fn main() -> void { o_position = vec4(c_scale * a_pos, 0.0, 1.0); return; } -entry_point vertex as "main" = main_vert; # fragment -[[location 0]] var a_uv : vec2; +[[location(0)]] var a_uv : vec2; #layout(set = 0, binding = 0) uniform texture2D u_texture; #layout(set = 0, binding = 1) uniform sampler u_sampler; -[[location 0]] var o_color : vec4; +[[location(0)]] var o_color : vec4; -fn main_frag() -> void { +[[stage(fragment)]] +fn main() -> void { o_color = vec4(1.0, 0.0, 0.0, 1.0); #TODO: sample return; } -entry_point fragment as "main" = main_frag; diff --git a/test-data/simple/simple.expected.ron b/test-data/simple/simple.expected.ron index 5d6863579e..a4b55e4432 100644 --- a/test-data/simple/simple.expected.ron +++ b/test-data/simple/simple.expected.ron @@ -64,58 +64,56 @@ ), ), ], - functions: [ - ( - name: Some("main"), - parameter_types: [], - return_type: None, - global_usage: [ - ( - bits: 1, - ), - ( - bits: 2, - ), - ], - local_variables: [ - ( - name: Some("w"), - ty: 3, - init: Some(3), - ), - ], - expressions: [ - GlobalVariable(1), - GlobalVariable(2), - Constant(1), - LocalVariable(1), - Constant(2), - Compose( - ty: 2, - components: [ - 1, - 5, - 4, - ], - ), - ], - body: [ - Empty, - Store( - pointer: 2, - value: 6, - ), - Return( - value: None, - ), - ], + functions: [], + entry_points: { + (Vertex, "main"): ( + early_depth_test: None, + workgroup_size: (1, 1, 1), + function: ( + name: Some("main"), + parameter_types: [], + return_type: None, + global_usage: [ + ( + bits: 1, + ), + ( + bits: 2, + ), + ], + local_variables: [ + ( + name: Some("w"), + ty: 3, + init: Some(3), + ), + ], + expressions: [ + GlobalVariable(1), + GlobalVariable(2), + Constant(1), + LocalVariable(1), + Constant(2), + Compose( + ty: 2, + components: [ + 1, + 5, + 4, + ], + ), + ], + body: [ + Empty, + Store( + pointer: 2, + value: 6, + ), + Return( + value: None, + ), + ], + ), ), - ], - entry_points: [ - ( - stage: Vertex, - name: "main", - function: 1, - ), - ], + }, ) \ No newline at end of file diff --git a/test-data/simple/simple.wgsl b/test-data/simple/simple.wgsl index 85eef76145..b35b83b426 100644 --- a/test-data/simple/simple.wgsl +++ b/test-data/simple/simple.wgsl @@ -1,10 +1,10 @@ # vertex -[[location 0]] var a_pos : vec2; -[[location 0]] var o_pos : vec4; +[[location(0)]] var a_pos : vec2; +[[location(0)]] var o_pos : vec4; +[[stage(vertex)]] fn main() -> void { var w: f32 = 1.0; o_pos = vec4(a_pos, 0.0, w); return; } -entry_point vertex as "main" = main; diff --git a/tests/convert.rs b/tests/convert.rs index d6a1243c39..97dbe84099 100644 --- a/tests/convert.rs +++ b/tests/convert.rs @@ -33,7 +33,10 @@ fn convert_quad() { use naga::back::msl; let mut binding_map = msl::BindingMap::default(); binding_map.insert( - msl::BindSource { set: 0, binding: 0 }, + msl::BindSource { + group: 0, + binding: 0, + }, msl::BindTarget { buffer: None, texture: Some(1), @@ -42,7 +45,10 @@ fn convert_quad() { }, ); binding_map.insert( - msl::BindSource { set: 0, binding: 1 }, + msl::BindSource { + group: 0, + binding: 1, + }, msl::BindTarget { buffer: None, texture: None, @@ -67,7 +73,10 @@ fn convert_boids() { use naga::back::msl; let mut binding_map = msl::BindingMap::default(); binding_map.insert( - msl::BindSource { set: 0, binding: 0 }, + msl::BindSource { + group: 0, + binding: 0, + }, msl::BindTarget { buffer: Some(0), texture: None, @@ -76,7 +85,10 @@ fn convert_boids() { }, ); binding_map.insert( - msl::BindSource { set: 0, binding: 1 }, + msl::BindSource { + group: 0, + binding: 1, + }, msl::BindTarget { buffer: Some(1), texture: None, @@ -85,7 +97,10 @@ fn convert_boids() { }, ); binding_map.insert( - msl::BindSource { set: 0, binding: 2 }, + msl::BindSource { + group: 0, + binding: 2, + }, msl::BindTarget { buffer: Some(2), texture: None, @@ -113,7 +128,10 @@ fn convert_cube() { use naga::back::msl; let mut binding_map = msl::BindingMap::default(); binding_map.insert( - msl::BindSource { set: 0, binding: 0 }, + msl::BindSource { + group: 0, + binding: 0, + }, msl::BindTarget { buffer: Some(0), texture: None, @@ -122,7 +140,10 @@ fn convert_cube() { }, ); binding_map.insert( - msl::BindSource { set: 0, binding: 1 }, + msl::BindSource { + group: 0, + binding: 1, + }, msl::BindTarget { buffer: None, texture: Some(1), @@ -131,7 +152,10 @@ fn convert_cube() { }, ); binding_map.insert( - msl::BindSource { set: 0, binding: 2 }, + msl::BindSource { + group: 0, + binding: 2, + }, msl::BindTarget { buffer: None, texture: None, @@ -154,9 +178,7 @@ fn convert_phong_lighting() { let module = load_glsl( "glsl_phong_lighting.frag", "main", - naga::ShaderStage::Fragment { - early_depth_test: None, - }, + naga::ShaderStage::Fragment, ); naga::proc::Validator::new().validate(&module).unwrap(); @@ -176,9 +198,7 @@ fn convert_phong_lighting() { // let module = load_glsl( // "glsl_constant_expression.vert", // "main", -// naga::ShaderStage::Fragment { -// early_depth_test: None, -// }, +// naga::ShaderStage::Fragment, // ); // naga::proc::Validator::new().validate(&module).unwrap(); // }