From 68bbf00249451343046b0a9f84ea72f3effdcc12 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Capucho?= Date: Sun, 31 Jan 2021 16:57:01 +0000 Subject: [PATCH] Added support for initalizers and constants --- src/front/glsl/ast.rs | 31 ++++++- src/front/glsl/functions.rs | 10 +++ src/front/glsl/lex.rs | 1 + src/front/glsl/parser.rs | 91 +++++++++++++------ src/front/glsl/parser_tests.rs | 56 ++++++++++++ src/front/glsl/variables.rs | 3 + src/lib.rs | 4 +- src/proc/constants.rs | 157 +++++++++++++++++++++++++++++++++ src/proc/mod.rs | 2 + 9 files changed, 321 insertions(+), 34 deletions(-) create mode 100644 src/proc/constants.rs diff --git a/src/front/glsl/ast.rs b/src/front/glsl/ast.rs index 647c3a6272..4ebc2da023 100644 --- a/src/front/glsl/ast.rs +++ b/src/front/glsl/ast.rs @@ -1,7 +1,7 @@ use super::error::ErrorKind; use crate::{ - proc::{ResolveContext, Typifier}, - Arena, BinaryOperator, Binding, Expression, FastHashMap, Function, FunctionArgument, + proc::{ConstantSolver, ResolveContext, Typifier}, + Arena, BinaryOperator, Binding, Constant, Expression, FastHashMap, Function, FunctionArgument, GlobalVariable, Handle, Interpolation, LocalVariable, Module, ShaderStage, Statement, StorageClass, Type, }; @@ -15,6 +15,7 @@ pub struct Program { pub lookup_function: FastHashMap>, pub lookup_type: FastHashMap>, pub lookup_global_variables: FastHashMap>, + pub lookup_constants: FastHashMap>, pub context: Context, pub module: Module, } @@ -29,12 +30,14 @@ impl Program { lookup_function: FastHashMap::default(), lookup_type: FastHashMap::default(), lookup_global_variables: FastHashMap::default(), + lookup_constants: FastHashMap::default(), context: Context { expressions: Arena::::new(), local_variables: Arena::::new(), arguments: Vec::new(), scopes: vec![FastHashMap::default()], lookup_global_var_exps: FastHashMap::default(), + lookup_constant_exps: FastHashMap::default(), typifier: Typifier::new(), }, module: Module::generate_empty(), @@ -76,6 +79,21 @@ impl Program { Ok(()) => Ok(self.context.typifier.get(handle, &self.module.types)), } } + + pub fn solve_constant( + &mut self, + root: Handle, + ) -> Result, ErrorKind> { + let mut solver = ConstantSolver { + types: &self.module.types, + expressions: &self.context.expressions, + constants: &mut self.module.constants, + }; + + solver + .solve(root) + .map_err(|_| ErrorKind::SemanticError("Can't solve constant".into())) + } } #[derive(Debug)] @@ -91,6 +109,7 @@ pub struct Context { //TODO: Find less allocation heavy representation pub scopes: Vec>>, pub lookup_global_var_exps: FastHashMap>, + pub lookup_constant_exps: FastHashMap>, pub typifier: Typifier, } @@ -154,7 +173,7 @@ impl ExpressionRule { #[derive(Debug)] pub enum TypeQualifier { - StorageClass(StorageClass), + StorageQualifier(StorageQualifier), Binding(Binding), Interpolation(Interpolation), } @@ -177,3 +196,9 @@ pub struct FunctionCall { pub kind: FunctionCallKind, pub args: Vec, } + +#[derive(Debug, Clone, Copy)] +pub enum StorageQualifier { + StorageClass(StorageClass), + Const, +} diff --git a/src/front/glsl/functions.rs b/src/front/glsl/functions.rs index 42d2721b1c..7ec4b4e11e 100644 --- a/src/front/glsl/functions.rs +++ b/src/front/glsl/functions.rs @@ -179,6 +179,16 @@ impl Program { } } } + + for (handle, constant) in self.module.constants.iter() { + if let Some(name) = constant.name.as_ref() { + let expr = self + .context + .expressions + .append(Expression::Constant(handle)); + self.context.lookup_constant_exps.insert(name.clone(), expr); + } + } } pub fn function_definition(&mut self, mut f: Function, mut block: Block) -> Function { diff --git a/src/front/glsl/lex.rs b/src/front/glsl/lex.rs index e5412c8a8d..d3bb1645bb 100644 --- a/src/front/glsl/lex.rs +++ b/src/front/glsl/lex.rs @@ -182,6 +182,7 @@ impl<'a> Lexer<'a> { "for" => Some(Token::For(meta)), // types "void" => Some(Token::Void(meta)), + "const" => Some(Token::Const(meta)), word => { let token = match parse_type(word) { Some(t) => Token::TypeName((meta, t)), diff --git a/src/front/glsl/parser.rs b/src/front/glsl/parser.rs index 7985313c39..b9c889cd76 100644 --- a/src/front/glsl/parser.rs +++ b/src/front/glsl/parser.rs @@ -114,7 +114,7 @@ pomelo! { %type layout_qualifier_id (String, u32); %type type_qualifier Vec; %type single_type_qualifier TypeQualifier; - %type storage_qualifier StorageClass; + %type storage_qualifier StorageQualifier; %type interpolation_qualifier Interpolation; %type Interpolation Interpolation; @@ -627,7 +627,7 @@ pomelo! { } single_type_qualifier ::= storage_qualifier(s) { - TypeQualifier::StorageClass(s) + TypeQualifier::StorageQualifier(s) } single_type_qualifier ::= layout_qualifier(l) { TypeQualifier::Binding(l) @@ -639,19 +639,21 @@ pomelo! { // single_type_qualifier ::= invariant_qualifier; // single_type_qualifier ::= precise_qualifier; - // storage_qualifier ::= Const + storage_qualifier ::= Const { + StorageQualifier::Const + } // storage_qualifier ::= InOut; storage_qualifier ::= In { - StorageClass::Input + StorageQualifier::StorageClass(StorageClass::Input) } storage_qualifier ::= Out { - StorageClass::Output + StorageQualifier::StorageClass(StorageClass::Output) } // storage_qualifier ::= Centroid; // storage_qualifier ::= Patch; // storage_qualifier ::= Sample; storage_qualifier ::= Uniform { - StorageClass::Uniform + StorageQualifier::StorageClass(StorageClass::Uniform) } //TODO: other storage qualifiers @@ -1062,32 +1064,63 @@ pomelo! { } external_declaration ::= declaration(d) { if let Some(d) = d { - let class = d.type_qualifiers.iter().find_map(|tq| { - if let TypeQualifier::StorageClass(sc) = tq { Some(*sc) } else { None } - }).unwrap_or(StorageClass::Private); + // TODO: handle multiple storage qualifiers + let storage = d.type_qualifiers.iter().find_map(|tq| { + if let TypeQualifier::StorageQualifier(sc) = tq { Some(*sc) } else { None } + }).unwrap_or(StorageQualifier::StorageClass(StorageClass::Private)); - let binding = d.type_qualifiers.iter().find_map(|tq| { - if let TypeQualifier::Binding(b) = tq { Some(b.clone()) } else { None } - }); + match storage { + StorageQualifier::StorageClass(class) => { + // TODO: Check that the storage qualifiers allow for the bindings + let binding = d.type_qualifiers.iter().find_map(|tq| { + if let TypeQualifier::Binding(b) = tq { Some(b.clone()) } else { None } + }); - let interpolation = d.type_qualifiers.iter().find_map(|tq| { - if let TypeQualifier::Interpolation(i) = tq { Some(*i) } else { None } - }); + let interpolation = d.type_qualifiers.iter().find_map(|tq| { + if let TypeQualifier::Interpolation(i) = tq { Some(*i) } else { None } + }); - for (id, initializer) in d.ids_initializers { - let h = extra.module.global_variables.fetch_or_append( - GlobalVariable { - name: id.clone(), - class, - binding: binding.clone(), - ty: d.ty, - init: None, - interpolation, - storage_access: StorageAccess::empty(), //TODO - }, - ); - if let Some(id) = id { - extra.lookup_global_variables.insert(id, h); + for (id, initializer) in d.ids_initializers { + let init = initializer.map(|init| extra.solve_constant(init.expression)).transpose()?; + + let h = extra.module.global_variables.fetch_or_append( + GlobalVariable { + name: id.clone(), + class, + binding: binding.clone(), + ty: d.ty, + init, + interpolation, + storage_access: StorageAccess::empty(), //TODO + }, + ); + if let Some(id) = id { + extra.lookup_global_variables.insert(id, h); + } + } + } + StorageQualifier::Const => { + for (id, initializer) in d.ids_initializers { + if let Some(init) = initializer { + let constant = extra.solve_constant(init.expression)?; + let inner = extra.module.constants[constant].inner.clone(); + + let h = extra.module.constants.fetch_or_append( + Constant { + name: id.clone(), + specialization: None, // TODO + inner + }, + ); + if let Some(id) = id { + extra.lookup_constants.insert(id.clone(), h); + let expr = extra.context.expressions.append(Expression::Constant(h)); + extra.context.lookup_constant_exps.insert(id, expr); + } + } else { + return Err(ErrorKind::SemanticError("Constants must have an initalizer".into())) + } + } } } } diff --git a/src/front/glsl/parser_tests.rs b/src/front/glsl/parser_tests.rs index b586c8bbc1..1e0a8e719b 100644 --- a/src/front/glsl/parser_tests.rs +++ b/src/front/glsl/parser_tests.rs @@ -231,3 +231,59 @@ fn functions() { ) .unwrap(); } + +#[test] +fn constants() { + use crate::{Constant, ConstantInner, ScalarValue}; + + let program = parse_program( + r#" + # version 450 + const float a = 1.0; + float global = a; + const flat float b = a; + "#, + ShaderStage::Vertex, + ) + .unwrap(); + + let mut constants = program.module.constants.iter(); + + assert_eq!( + constants.next().unwrap().1, + &Constant { + name: None, + specialization: None, + inner: ConstantInner::Scalar { + width: 4, + value: ScalarValue::Float(1.0) + } + } + ); + + assert_eq!( + constants.next().unwrap().1, + &Constant { + name: Some(String::from("a")), + specialization: None, + inner: ConstantInner::Scalar { + width: 4, + value: ScalarValue::Float(1.0) + } + } + ); + + assert_eq!( + constants.next().unwrap().1, + &Constant { + name: Some(String::from("b")), + specialization: None, + inner: ConstantInner::Scalar { + width: 4, + value: ScalarValue::Float(1.0) + } + } + ); + + assert!(constants.next().is_none()); +} diff --git a/src/front/glsl/variables.rs b/src/front/glsl/variables.rs index 70022a2ac5..78b522ec55 100644 --- a/src/front/glsl/variables.rs +++ b/src/front/glsl/variables.rs @@ -15,6 +15,9 @@ impl Program { if let Some(global_var) = self.context.lookup_global_var_exps.get(name) { return Ok(Some(*global_var)); } + if let Some(constant) = self.context.lookup_constant_exps.get(name) { + return Ok(Some(*constant)); + } match name { "gl_Position" => { #[cfg(feature = "glsl-validate")] diff --git a/src/lib.rs b/src/lib.rs index 297454fe97..9db837dd39 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -401,7 +401,7 @@ pub struct Constant { } /// A literal scalar value, used in constants. -#[derive(Debug, PartialEq)] +#[derive(Debug, PartialEq, Clone)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] pub enum ScalarValue { @@ -412,7 +412,7 @@ pub enum ScalarValue { } /// Additional information, dependendent on the kind of constant. -#[derive(Debug, PartialEq)] +#[derive(Debug, PartialEq, Clone)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] pub enum ConstantInner { diff --git a/src/proc/constants.rs b/src/proc/constants.rs new file mode 100644 index 0000000000..ec705515ba --- /dev/null +++ b/src/proc/constants.rs @@ -0,0 +1,157 @@ +use crate::{ + arena::{Arena, Handle}, + ArraySize, Constant, ConstantInner, Expression, ScalarValue, Type, +}; + +#[derive(Debug)] +pub struct ConstantSolver<'a> { + pub types: &'a Arena, + pub expressions: &'a Arena, + pub constants: &'a mut Arena, +} + +#[derive(Clone, Debug, PartialEq, thiserror::Error)] +pub enum ConstantSolvingError { + #[error("Constants cannot access function arguments")] + FunctionArg, + #[error("Constants cannot access global variables")] + GlobalVariable, + #[error("Constants cannot access local variables")] + LocalVariable, + #[error("Cannot get the array length of a non array type")] + InvalidArrayLengthArg, + #[error("Constants cannot get the array length of a dynamically sized array")] + ArrayLengthDynamic, + #[error("Constants cannot call functions")] + Call, + #[error("Constants don't support relational functions")] + Relational, + #[error("Constants don't support derivative functions")] + Derivative, + #[error("Constants don't support select expressions")] + Select, + #[error("Constants don't support load expressions")] + Load, + #[error("Constants don't support image expressions")] + ImageExpression, + #[error("Cannot access the type")] + InvalidAccessBase, + #[error("Cannot access at the index")] + InvalidAccessIndex, + #[error("Cannot access with index of type")] + InvalidAccessIndexTy, +} + +impl<'a> ConstantSolver<'a> { + pub fn solve( + &mut self, + expr: Handle, + ) -> Result, ConstantSolvingError> { + match self.expressions[expr] { + Expression::Constant(constant) => Ok(constant), + Expression::AccessIndex { base, index } => self.access(base, index as usize), + Expression::Access { base, index } => { + let index = self.solve(index)?; + + self.access(base, self.constant_index(index)?) + } + Expression::Compose { ty, ref components } => { + let components = components + .iter() + .map(|c| self.solve(*c)) + .collect::>()?; + + Ok(self.constants.fetch_or_append(Constant { + name: None, + specialization: None, + inner: ConstantInner::Composite { ty, components }, + })) + } + Expression::Unary { .. } => todo!(), + Expression::Binary { .. } => todo!(), + Expression::Math { .. } => todo!(), + Expression::As { .. } => todo!(), + Expression::ArrayLength(expr) => { + let array = self.solve(expr)?; + + match self.constants[array].inner { + crate::ConstantInner::Scalar { .. } => { + Err(ConstantSolvingError::InvalidArrayLengthArg) + } + crate::ConstantInner::Composite { ty, .. } => match self.types[ty].inner { + crate::TypeInner::Array { size, .. } => match size { + crate::ArraySize::Constant(constant) => Ok(constant), + crate::ArraySize::Dynamic => { + Err(ConstantSolvingError::ArrayLengthDynamic) + } + }, + _ => Err(ConstantSolvingError::InvalidArrayLengthArg), + }, + } + } + + Expression::Load { .. } => Err(ConstantSolvingError::Load), + Expression::Select { .. } => Err(ConstantSolvingError::Select), + Expression::LocalVariable(_) => Err(ConstantSolvingError::LocalVariable), + Expression::Derivative { .. } => Err(ConstantSolvingError::Derivative), + Expression::Relational { .. } => Err(ConstantSolvingError::Relational), + Expression::Call { .. } => Err(ConstantSolvingError::Call), + Expression::FunctionArgument(_) => Err(ConstantSolvingError::FunctionArg), + Expression::GlobalVariable(_) => Err(ConstantSolvingError::GlobalVariable), + Expression::ImageSample { .. } => Err(ConstantSolvingError::ImageExpression), + Expression::ImageLoad { .. } => Err(ConstantSolvingError::ImageExpression), + } + } + + fn access( + &mut self, + base: Handle, + index: usize, + ) -> Result, ConstantSolvingError> { + let base = self.solve(base)?; + + match self.constants[base].inner { + crate::ConstantInner::Scalar { .. } => Err(ConstantSolvingError::InvalidAccessBase), + crate::ConstantInner::Composite { ty, ref components } => match self.types[ty].inner { + crate::TypeInner::Vector { size, .. } => { + if size as usize <= index { + Err(ConstantSolvingError::InvalidAccessIndex) + } else { + Ok(components[index]) + } + } + crate::TypeInner::Matrix { .. } => todo!(), + crate::TypeInner::Array { size, .. } => match size { + ArraySize::Constant(constant) => { + let size = self.constant_index(constant)?; + + if size <= index { + Err(ConstantSolvingError::InvalidAccessIndex) + } else { + Ok(components[index]) + } + } + ArraySize::Dynamic => Err(ConstantSolvingError::ArrayLengthDynamic), + }, + crate::TypeInner::Struct { ref members, .. } => { + if members.len() <= index { + Err(ConstantSolvingError::InvalidAccessIndex) + } else { + Ok(components[index]) + } + } + _ => Err(ConstantSolvingError::InvalidAccessBase), + }, + } + } + + fn constant_index(&self, constant: Handle) -> Result { + match self.constants[constant].inner { + ConstantInner::Scalar { + value: ScalarValue::Uint(index), + .. + } => Ok(index as usize), + _ => Err(ConstantSolvingError::InvalidAccessIndexTy), + } + } +} diff --git a/src/proc/mod.rs b/src/proc/mod.rs index 5dec4285bf..7aa2ca0568 100644 --- a/src/proc/mod.rs +++ b/src/proc/mod.rs @@ -2,6 +2,7 @@ #[cfg(feature = "petgraph")] mod call_graph; +mod constants; mod interface; mod layouter; mod namer; @@ -11,6 +12,7 @@ mod validator; #[cfg(feature = "petgraph")] pub use call_graph::{CallGraph, CallGraphBuilder}; +pub use constants::{ConstantSolver, ConstantSolvingError}; pub use interface::{Interface, Visitor}; pub use layouter::{Alignment, Layouter}; pub use namer::{EntryPointIndex, NameKey, Namer};