From 41b3865e5bb41b846e206de1a7bd57fbc1cc8f20 Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Tue, 3 Nov 2020 23:26:25 -0500 Subject: [PATCH] Make variable initializers to be const --- src/back/glsl.rs | 2 +- src/back/msl.rs | 8 +++- src/back/spv/writer.rs | 89 +++++++++++++++++-------------------- src/front/glsl/parser.rs | 41 ++++++++--------- src/front/glsl/variables.rs | 2 + src/front/spv/mod.rs | 15 ++++--- src/front/wgsl/mod.rs | 64 +++++++++++++------------- src/lib.rs | 7 +-- src/proc/interface.rs | 9 +--- test-data/simple/module.ron | 4 +- 10 files changed, 119 insertions(+), 122 deletions(-) diff --git a/src/back/glsl.rs b/src/back/glsl.rs index 29e677eaea..7c2d84c57b 100644 --- a/src/back/glsl.rs +++ b/src/back/glsl.rs @@ -706,7 +706,7 @@ pub fn write<'a>( write!( &mut buf, " = {}", - write_expression(&func.expressions[init], module, &mut builder, &mut manager)? + write_constant(&module.constants[init], module, &mut builder, &mut manager)? )?; } writeln!(&mut buf, ";")?; diff --git a/src/back/msl.rs b/src/back/msl.rs index cade25218e..6aa52d062a 100644 --- a/src/back/msl.rs +++ b/src/back/msl.rs @@ -1025,7 +1025,7 @@ impl Writer { )?; if let Some(value) = local.init { write!(self.out, " = ")?; - self.put_expression(value, fun, module)?; + self.put_constant(value, module)?; } writeln!(self.out, ";")?; } @@ -1235,6 +1235,10 @@ impl Writer { write!(self.out, "\t")?; tyvar.try_fmt(&mut self.out)?; resolved.try_fmt_decorated(&mut self.out, separator)?; + if let Some(value) = var.init { + write!(self.out, " = ")?; + self.put_constant(value, module)?; + } writeln!(self.out)?; } writeln!(self.out, ") {{")?; @@ -1255,7 +1259,7 @@ impl Writer { )?; if let Some(value) = local.init { write!(self.out, " = ")?; - self.put_expression(value, fun, module)?; + self.put_constant(value, module)?; } writeln!(self.out, ";")?; } diff --git a/src/back/spv/writer.rs b/src/back/spv/writer.rs index 159869bec6..32b5e53d92 100644 --- a/src/back/spv/writer.rs +++ b/src/back/spv/writer.rs @@ -195,15 +195,13 @@ impl Writer { fn get_global_variable_id( &mut self, - arena: &crate::Arena, - global_arena: &crate::Arena, + ir_module: &crate::Module, handle: crate::Handle, ) -> Word { match self.lookup_global_variable.entry(handle) { Entry::Occupied(e) => *e.get(), _ => { - let global_variable = &global_arena[handle]; - let (instruction, id) = self.write_global_variable(arena, global_variable, handle); + let (instruction, id) = self.write_global_variable(ir_module, handle); instruction.to_words(&mut self.logical_layout.declarations); id } @@ -277,15 +275,9 @@ impl Writer { for (_, variable) in ir_function.local_variables.iter() { let id = self.generate_id(); - let init_word = match variable.init { - Some(exp) => match &ir_function.expressions[exp] { - crate::Expression::Constant(handle) => { - Some(self.get_constant_id(*handle, ir_module)) - } - _ => unreachable!(), - }, - None => None, - }; + let init_word = variable + .init + .map(|constant| self.get_constant_id(constant, ir_module)); let pointer_id = self.get_pointer_id(&ir_module.types, variable.ty, spirv::StorageClass::Function); @@ -372,11 +364,7 @@ impl Writer { .zip(&entry_point.function.global_usage) { if usage.contains(crate::GlobalUse::STORE) || usage.contains(crate::GlobalUse::LOAD) { - let id = self.get_global_variable_id( - &ir_module.types, - &ir_module.global_variables, - handle, - ); + let id = self.get_global_variable_id(ir_module, handle); interface_ids.push(id); } } @@ -681,17 +669,21 @@ impl Writer { fn write_global_variable( &mut self, - arena: &crate::Arena, - global_variable: &crate::GlobalVariable, + ir_module: &crate::Module, handle: crate::Handle, ) -> (Instruction, Word) { + let global_variable = &ir_module.global_variables[handle]; let id = self.generate_id(); let class = self.parse_to_spirv_storage_class(global_variable.class); self.try_add_capabilities(class.required_capabilities()); - let pointer_id = self.get_pointer_id(arena, global_variable.ty, class); - let instruction = super::instructions::instruction_variable(pointer_id, id, class, None); + let init_word = global_variable + .init + .map(|constant| self.get_constant_id(constant, ir_module)); + let pointer_id = self.get_pointer_id(&ir_module.types, global_variable.ty, class); + let instruction = + super::instructions::instruction_variable(pointer_id, id, class, init_word); if self.writer_flags.contains(WriterFlags::DEBUG) { if let Some(ref name) = global_variable.name { @@ -825,23 +817,19 @@ impl Writer { block: &mut Block, function: &mut Function, ) -> Result { - match expression { + match *expression { crate::Expression::GlobalVariable(handle) => { - let var = &ir_module.global_variables[*handle]; - let id = self.get_global_variable_id( - &ir_module.types, - &ir_module.global_variables, - *handle, - ); + let var = &ir_module.global_variables[handle]; + let id = self.get_global_variable_id(ir_module, handle); Ok(Some((id, Some(var.ty)))) } crate::Expression::Constant(handle) => { - let var = &ir_module.constants[*handle]; - let id = self.get_constant_id(*handle, ir_module); + let var = &ir_module.constants[handle]; + let id = self.get_constant_id(handle, ir_module); Ok(Some((id, Some(var.ty)))) } - crate::Expression::Compose { ty, components } => { - let base_type_id = self.get_type_id(&ir_module.types, LookupType::Handle(*ty)); + crate::Expression::Compose { ty, ref components } => { + let base_type_id = self.get_type_id(&ir_module.types, LookupType::Handle(ty)); let mut constituent_ids = Vec::with_capacity(components.len()); for component in components { @@ -853,7 +841,7 @@ impl Writer { } let constituent_ids_slice = constituent_ids.as_slice(); - let id = match ir_module.types[*ty].inner { + let id = match ir_module.types[ty].inner { crate::TypeInner::Vector { .. } => { self.write_composite_construct(base_type_id, constituent_ids_slice, block) } @@ -893,14 +881,14 @@ impl Writer { _ => unreachable!(), }; - Ok(Some((id, Some(*ty)))) + Ok(Some((id, Some(ty)))) } crate::Expression::Binary { op, left, right } => { match op { crate::BinaryOperator::Multiply => { let id = self.generate_id(); - let left_expression = &ir_function.expressions[*left]; - let right_expression = &ir_function.expressions[*right]; + let left_expression = &ir_function.expressions[left]; + let right_expression = &ir_function.expressions[right]; let (left_id, left_ty) = self .write_expression( ir_module, @@ -1049,7 +1037,7 @@ impl Writer { } } crate::Expression::LocalVariable(variable) => { - let var = &ir_function.local_variables[*variable]; + let var = &ir_function.local_variables[variable]; function .variables .iter() @@ -1058,19 +1046,22 @@ impl Writer { .ok_or_else(|| Error::UnknownLocalVariable(var.clone())) } crate::Expression::FunctionParameter(index) => { - let handle = ir_function.parameter_types.get(*index as usize).unwrap(); + let handle = ir_function.parameter_types.get(index as usize).unwrap(); let type_id = self.get_type_id(&ir_module.types, LookupType::Handle(*handle)); let load_id = self.generate_id(); block.body.push(super::instructions::instruction_load( type_id, load_id, - function.parameters[*index as usize].result_id.unwrap(), + function.parameters[index as usize].result_id.unwrap(), None, )); Ok(Some((load_id, Some(*handle)))) } - crate::Expression::Call { origin, arguments } => match origin { + crate::Expression::Call { + ref origin, + ref arguments, + } => match origin { crate::FunctionOrigin::Local(local_function) => { let origin_function = &ir_module.functions[*local_function]; let id = self.generate_id(); @@ -1119,7 +1110,7 @@ impl Writer { .push(super::instructions::instruction_function_call( return_type_id, id, - *self.lookup_function.get(local_function).unwrap(), + *self.lookup_function.get(&local_function).unwrap(), argument_ids.as_slice(), )); Ok(Some((id, None))) @@ -1139,7 +1130,7 @@ impl Writer { .write_expression( ir_module, ir_function, - &ir_function.expressions[*expr], + &ir_function.expressions[expr], block, function, )? @@ -1153,10 +1144,10 @@ impl Writer { } => { let kind_type_id = self.get_type_id( &ir_module.types, - LookupType::Local(LocalType::Scalar { kind: *kind, width }), + LookupType::Local(LocalType::Scalar { kind, width }), ); - if *convert { + if convert { super::instructions::instruction_bit_cast(kind_type_id, id, expr_id) } else { match (expr_kind, kind) { @@ -1207,7 +1198,7 @@ impl Writer { depth_ref: _, } => { // image - let image_expression = &ir_function.expressions[*image]; + let image_expression = &ir_function.expressions[image]; let (image_id, image_ty) = self .write_expression(ir_module, ir_function, image_expression, block, function)? .ok_or(Error::EmptyValue)?; @@ -1237,7 +1228,7 @@ impl Writer { ); // sampler - let sampler_expression = &ir_function.expressions[*sampler]; + let sampler_expression = &ir_function.expressions[sampler]; let (sampler_id, sampler_ty) = self .write_expression(ir_module, ir_function, sampler_expression, block, function)? .ok_or(Error::EmptyValue)?; @@ -1259,7 +1250,7 @@ impl Writer { }; // coordinate - let coordinate_expression = &ir_function.expressions[*coordinate]; + let coordinate_expression = &ir_function.expressions[coordinate]; let (coordinate_id, coordinate_ty) = self .write_expression( ir_module, @@ -1469,7 +1460,7 @@ impl Writer { } for (handle, _) in ir_module.global_variables.iter() { - self.get_global_variable_id(&ir_module.types, &ir_module.global_variables, handle); + self.get_global_variable_id(ir_module, handle); } for (handle, _) in ir_module.constants.iter() { diff --git a/src/front/glsl/parser.rs b/src/front/glsl/parser.rs index 5e6a16d341..69827f9637 100644 --- a/src/front/glsl/parser.rs +++ b/src/front/glsl/parser.rs @@ -760,18 +760,33 @@ pomelo! { return Err(ErrorKind::VariableAlreadyDeclared(id)) } } + let mut init_exp: Option> = None; let localVar = extra.context.local_variables.append( LocalVariable { name: Some(id.clone()), ty: d.ty, init: initializer.map(|i| { statements.extend(i.statements); - i.expression - }), + if let Expression::Constant(constant) = extra.context.expressions[i.expression] { + Some(constant) + } else { + init_exp = Some(i.expression); + None + } + }).flatten(), } ); let exp = extra.context.expressions.append(Expression::LocalVariable(localVar)); extra.context.add_local_var(id, exp); + + if let Some(value) = init_exp { + statements.push( + Statement::Store { + pointer: exp, + value, + } + ); + } } match statements.len() { 1 => statements.remove(0), @@ -1066,33 +1081,13 @@ pomelo! { class, binding: binding.clone(), ty: d.ty, + init: None, interpolation, storage_access: StorageAccess::empty(), //TODO }, ); if let Some(id) = id { extra.lookup_global_variables.insert(id, h); - } else { - // variables in interface blocks without an instance name are in the global namespace - // https://www.khronos.org/opengl/wiki/Interface_Block_(GLSL) - if let TypeInner::Struct { members } = &extra.module.types[d.ty].inner { - for m in members { - if let Some(name) = &m.name { - let h = extra - .module - .global_variables - .fetch_or_append(GlobalVariable { - name: Some(name.into()), - class, - binding: binding.clone(), - ty: m.ty, - interpolation, - storage_access: StorageAccess::empty(), // TODO - }); - extra.lookup_global_variables.insert(name.into(), h); - } - } - } } } } diff --git a/src/front/glsl/variables.rs b/src/front/glsl/variables.rs index c4d37b7bb3..9f76335685 100644 --- a/src/front/glsl/variables.rs +++ b/src/front/glsl/variables.rs @@ -38,6 +38,7 @@ impl Program { width: 4, }, }), + init: None, interpolation: None, storage_access: StorageAccess::empty(), }); @@ -72,6 +73,7 @@ impl Program { width: 4, }, }), + init: None, interpolation: None, storage_access: StorageAccess::empty(), }); diff --git a/src/front/spv/mod.rs b/src/front/spv/mod.rs index b59fd239b4..f5aca1cf49 100644 --- a/src/front/spv/mod.rs +++ b/src/front/spv/mod.rs @@ -554,8 +554,8 @@ impl> Parser { let init = if inst.wc > 4 { inst.expect(5)?; let init_id = self.next()?; - let lexp = self.lookup_expression.lookup(init_id)?; - Some(lexp.handle) + let lconst = self.lookup_constant.lookup(init_id)?; + Some(lconst.handle) } else { None }; @@ -2271,10 +2271,14 @@ impl> Parser { let type_id = self.next()?; let id = self.next()?; let storage_class = self.next()?; - if inst.wc != 4 { + let init = if inst.wc > 4 { inst.expect(5)?; - let _init = self.next()?; //TODO - } + let init_id = self.next()?; + let lconst = self.lookup_constant.lookup(init_id)?; + Some(lconst.handle) + } else { + None + }; let lookup_type = self.lookup_type.lookup(type_id)?; let dec = self .future_decor @@ -2338,6 +2342,7 @@ impl> Parser { class, binding, ty: lookup_type.handle, + init, interpolation: dec.interpolation, storage_access, }; diff --git a/src/front/wgsl/mod.rs b/src/front/wgsl/mod.rs index 05c2749946..ee84b040bd 100644 --- a/src/front/wgsl/mod.rs +++ b/src/front/wgsl/mod.rs @@ -36,8 +36,6 @@ pub enum Token<'a> { pub enum Error<'a> { #[error("unexpected token: {0:?}")] Unexpected(Token<'a>), - #[error("constant {0:?} doesn't match its type {1:?}")] - UnexpectedConstantType(crate::ConstantInner, Handle), #[error("unable to parse `{0}` as integer: {1}")] BadInteger(&'a str, std::num::ParseIntError), #[error("unable to parse `{1}` as float: {1}")] @@ -264,6 +262,7 @@ struct ParsedVariable<'a> { class: Option, ty: Handle, access: crate::StorageAccess, + init: Option>, } #[derive(Clone, Debug, Error)] @@ -374,9 +373,10 @@ impl Parser { fn parse_const_expression<'a>( &mut self, lexer: &mut Lexer<'a>, + self_ty: Handle, type_arena: &mut Arena, const_arena: &mut Arena, - ) -> Result> { + ) -> Result, Error<'a>> { self.scopes.push(Scope::ConstantExpr); let inner = match lexer.peek() { Token::Word("true") => { @@ -405,19 +405,21 @@ impl Parser { composite_ty, components.len(), )?; - let inner = self.parse_const_expression(lexer, type_arena, const_arena)?; - components.push(const_arena.fetch_or_append(crate::Constant { - name: None, - specialization: None, - inner, - ty, - })); + let component = + self.parse_const_expression(lexer, ty, type_arena, const_arena)?; + components.push(component); } crate::ConstantInner::Composite(components) } }; + let handle = const_arena.fetch_or_append(crate::Constant { + name: None, + specialization: None, + inner, + ty: self_ty, + }); self.scopes.pop(); - Ok(inner) + Ok(handle) } fn parse_primary_expression<'a>( @@ -943,10 +945,12 @@ impl Parser { } _ => crate::StorageAccess::empty(), }; - if lexer.skip(Token::Operation('=')) { - let _inner = self.parse_const_expression(lexer, type_arena, const_arena)?; - //TODO - } + let init = if lexer.skip(Token::Operation('=')) { + let handle = self.parse_const_expression(lexer, ty, type_arena, const_arena)?; + Some(handle) + } else { + None + }; lexer.expect(Token::Separator(';'))?; self.scopes.pop(); Ok(ParsedVariable { @@ -954,6 +958,7 @@ impl Parser { class, ty, access, + init, }) } @@ -1373,15 +1378,16 @@ impl Parser { "var" => { enum Init { Empty, - Uniform(Handle), + Constant(Handle), Variable(Handle), } let (name, ty) = self.parse_variable_ident_decl(lexer, context.types)?; let init = if lexer.skip(Token::Operation('=')) { let value = self.parse_general_expression(lexer, context.as_expression())?; - if let crate::Expression::Constant(_) = context.expressions[value] { - Init::Uniform(value) + if let crate::Expression::Constant(handle) = context.expressions[value] + { + Init::Constant(handle) } else { Init::Variable(value) } @@ -1393,7 +1399,7 @@ impl Parser { name: Some(name.to_owned()), ty, init: match init { - Init::Uniform(value) => Some(value), + Init::Constant(value) => Some(value), _ => None, }, }); @@ -1692,18 +1698,13 @@ impl Parser { Token::Word("const") => { let (name, ty) = self.parse_variable_ident_decl(lexer, &mut module.types)?; lexer.expect(Token::Operation('='))?; - let inner = - self.parse_const_expression(lexer, &mut module.types, &mut module.constants)?; - lexer.expect(Token::Separator(';'))?; - if !crate::proc::check_constant_type(&inner, &module.types[ty].inner) { - return Err(Error::UnexpectedConstantType(inner, ty)); - } - let const_handle = module.constants.append(crate::Constant { - name: Some(name.to_owned()), - specialization: None, - inner, + let const_handle = self.parse_const_expression( + lexer, ty, - }); + &mut module.types, + &mut module.constants, + )?; + lexer.expect(Token::Separator(';'))?; lookup_global_expression.insert(name, crate::Expression::Constant(const_handle)); } Token::Word("var") => { @@ -1725,6 +1726,7 @@ impl Parser { class, binding: binding.take(), ty: pvar.ty, + init: pvar.init, interpolation, storage_access: pvar.access, }); @@ -1813,5 +1815,5 @@ pub fn parse_str(source: &str) -> Result { #[test] fn parse_types() { assert!(parse_str("const a : i32 = 2;").is_ok()); - assert!(parse_str("const a : i32 = 2.0;").is_err()); + assert!(parse_str("const a : x32 = 2;").is_err()); } diff --git a/src/lib.rs b/src/lib.rs index c8c1b0107b..ceec1f2b81 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -406,8 +406,7 @@ pub struct Constant { } /// Additional information, dependendent on the kind of constant. -// Clone is used only for error reporting and is not intended for end users -#[derive(Clone, Debug, PartialEq)] +#[derive(Debug, PartialEq)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] pub enum ConstantInner { @@ -456,6 +455,8 @@ pub struct GlobalVariable { pub binding: Option, /// The type of this variable. pub ty: Handle, + /// Initial value for this variable. + pub init: Option>, /// The interpolation qualifier, if any. /// If the this `GlobalVariable` is a vertex output /// or fragment input, `None` corresponds to the @@ -475,7 +476,7 @@ pub struct LocalVariable { /// The type of this variable. pub ty: Handle, /// Initial value for this variable. - pub init: Option>, + pub init: Option>, } /// Operation that can be applied on a single value. diff --git a/src/proc/interface.rs b/src/proc/interface.rs index ca20d1958b..03cbf3e403 100644 --- a/src/proc/interface.rs +++ b/src/proc/interface.rs @@ -37,13 +37,7 @@ where self.traverse_expr(comp); } } - E::FunctionParameter(_) | E::GlobalVariable(_) => {} - E::LocalVariable(var) => { - let var = &self.local_variables[var]; - if let Some(init) = var.init { - self.traverse_expr(init); - } - } + E::FunctionParameter(_) | E::GlobalVariable(_) | E::LocalVariable(_) => {} E::Load { pointer } => { self.traverse_expr(pointer); } @@ -229,6 +223,7 @@ mod tests { class: StorageClass::Uniform, binding: None, ty: Handle::new(std::num::NonZeroU32::new(1).unwrap()), + init: None, interpolation: None, storage_access: StorageAccess::empty(), }; diff --git a/test-data/simple/module.ron b/test-data/simple/module.ron index c3c34fc6b8..bcc93f0203 100644 --- a/test-data/simple/module.ron +++ b/test-data/simple/module.ron @@ -48,6 +48,7 @@ class: Input, binding: Some(Location(0)), ty: 1, + init: None, interpolation: None, storage_access: ( bits: 0, @@ -58,6 +59,7 @@ class: Output, binding: Some(Location(0)), ty: 2, + init: None, interpolation: None, storage_access: ( bits: 0, @@ -85,7 +87,7 @@ ( name: Some("w"), ty: 3, - init: Some(3), + init: Some(1), ), ], expressions: [