Added support for initalizers and constants

This commit is contained in:
João Capucho
2021-01-31 16:57:01 +00:00
committed by Dzmitry Malyshau
parent e05baa2889
commit 68bbf00249
9 changed files with 321 additions and 34 deletions

View File

@@ -1,7 +1,7 @@
use super::error::ErrorKind;
use crate::{
proc::{ResolveContext, Typifier},
Arena, BinaryOperator, Binding, Expression, FastHashMap, Function, FunctionArgument,
proc::{ConstantSolver, ResolveContext, Typifier},
Arena, BinaryOperator, Binding, Constant, Expression, FastHashMap, Function, FunctionArgument,
GlobalVariable, Handle, Interpolation, LocalVariable, Module, ShaderStage, Statement,
StorageClass, Type,
};
@@ -15,6 +15,7 @@ pub struct Program {
pub lookup_function: FastHashMap<String, Handle<Function>>,
pub lookup_type: FastHashMap<String, Handle<Type>>,
pub lookup_global_variables: FastHashMap<String, Handle<GlobalVariable>>,
pub lookup_constants: FastHashMap<String, Handle<Constant>>,
pub context: Context,
pub module: Module,
}
@@ -29,12 +30,14 @@ impl Program {
lookup_function: FastHashMap::default(),
lookup_type: FastHashMap::default(),
lookup_global_variables: FastHashMap::default(),
lookup_constants: FastHashMap::default(),
context: Context {
expressions: Arena::<Expression>::new(),
local_variables: Arena::<LocalVariable>::new(),
arguments: Vec::new(),
scopes: vec![FastHashMap::default()],
lookup_global_var_exps: FastHashMap::default(),
lookup_constant_exps: FastHashMap::default(),
typifier: Typifier::new(),
},
module: Module::generate_empty(),
@@ -76,6 +79,21 @@ impl Program {
Ok(()) => Ok(self.context.typifier.get(handle, &self.module.types)),
}
}
pub fn solve_constant(
&mut self,
root: Handle<Expression>,
) -> Result<Handle<Constant>, ErrorKind> {
let mut solver = ConstantSolver {
types: &self.module.types,
expressions: &self.context.expressions,
constants: &mut self.module.constants,
};
solver
.solve(root)
.map_err(|_| ErrorKind::SemanticError("Can't solve constant".into()))
}
}
#[derive(Debug)]
@@ -91,6 +109,7 @@ pub struct Context {
//TODO: Find less allocation heavy representation
pub scopes: Vec<FastHashMap<String, Handle<Expression>>>,
pub lookup_global_var_exps: FastHashMap<String, Handle<Expression>>,
pub lookup_constant_exps: FastHashMap<String, Handle<Expression>>,
pub typifier: Typifier,
}
@@ -154,7 +173,7 @@ impl ExpressionRule {
#[derive(Debug)]
pub enum TypeQualifier {
StorageClass(StorageClass),
StorageQualifier(StorageQualifier),
Binding(Binding),
Interpolation(Interpolation),
}
@@ -177,3 +196,9 @@ pub struct FunctionCall {
pub kind: FunctionCallKind,
pub args: Vec<ExpressionRule>,
}
#[derive(Debug, Clone, Copy)]
pub enum StorageQualifier {
StorageClass(StorageClass),
Const,
}

View File

@@ -179,6 +179,16 @@ impl Program {
}
}
}
for (handle, constant) in self.module.constants.iter() {
if let Some(name) = constant.name.as_ref() {
let expr = self
.context
.expressions
.append(Expression::Constant(handle));
self.context.lookup_constant_exps.insert(name.clone(), expr);
}
}
}
pub fn function_definition(&mut self, mut f: Function, mut block: Block) -> Function {

View File

@@ -182,6 +182,7 @@ impl<'a> Lexer<'a> {
"for" => Some(Token::For(meta)),
// types
"void" => Some(Token::Void(meta)),
"const" => Some(Token::Const(meta)),
word => {
let token = match parse_type(word) {
Some(t) => Token::TypeName((meta, t)),

View File

@@ -114,7 +114,7 @@ pomelo! {
%type layout_qualifier_id (String, u32);
%type type_qualifier Vec<TypeQualifier>;
%type single_type_qualifier TypeQualifier;
%type storage_qualifier StorageClass;
%type storage_qualifier StorageQualifier;
%type interpolation_qualifier Interpolation;
%type Interpolation Interpolation;
@@ -627,7 +627,7 @@ pomelo! {
}
single_type_qualifier ::= storage_qualifier(s) {
TypeQualifier::StorageClass(s)
TypeQualifier::StorageQualifier(s)
}
single_type_qualifier ::= layout_qualifier(l) {
TypeQualifier::Binding(l)
@@ -639,19 +639,21 @@ pomelo! {
// single_type_qualifier ::= invariant_qualifier;
// single_type_qualifier ::= precise_qualifier;
// storage_qualifier ::= Const
storage_qualifier ::= Const {
StorageQualifier::Const
}
// storage_qualifier ::= InOut;
storage_qualifier ::= In {
StorageClass::Input
StorageQualifier::StorageClass(StorageClass::Input)
}
storage_qualifier ::= Out {
StorageClass::Output
StorageQualifier::StorageClass(StorageClass::Output)
}
// storage_qualifier ::= Centroid;
// storage_qualifier ::= Patch;
// storage_qualifier ::= Sample;
storage_qualifier ::= Uniform {
StorageClass::Uniform
StorageQualifier::StorageClass(StorageClass::Uniform)
}
//TODO: other storage qualifiers
@@ -1062,32 +1064,63 @@ pomelo! {
}
external_declaration ::= declaration(d) {
if let Some(d) = d {
let class = d.type_qualifiers.iter().find_map(|tq| {
if let TypeQualifier::StorageClass(sc) = tq { Some(*sc) } else { None }
}).unwrap_or(StorageClass::Private);
// TODO: handle multiple storage qualifiers
let storage = d.type_qualifiers.iter().find_map(|tq| {
if let TypeQualifier::StorageQualifier(sc) = tq { Some(*sc) } else { None }
}).unwrap_or(StorageQualifier::StorageClass(StorageClass::Private));
let binding = d.type_qualifiers.iter().find_map(|tq| {
if let TypeQualifier::Binding(b) = tq { Some(b.clone()) } else { None }
});
match storage {
StorageQualifier::StorageClass(class) => {
// TODO: Check that the storage qualifiers allow for the bindings
let binding = d.type_qualifiers.iter().find_map(|tq| {
if let TypeQualifier::Binding(b) = tq { Some(b.clone()) } else { None }
});
let interpolation = d.type_qualifiers.iter().find_map(|tq| {
if let TypeQualifier::Interpolation(i) = tq { Some(*i) } else { None }
});
let interpolation = d.type_qualifiers.iter().find_map(|tq| {
if let TypeQualifier::Interpolation(i) = tq { Some(*i) } else { None }
});
for (id, initializer) in d.ids_initializers {
let h = extra.module.global_variables.fetch_or_append(
GlobalVariable {
name: id.clone(),
class,
binding: binding.clone(),
ty: d.ty,
init: None,
interpolation,
storage_access: StorageAccess::empty(), //TODO
},
);
if let Some(id) = id {
extra.lookup_global_variables.insert(id, h);
for (id, initializer) in d.ids_initializers {
let init = initializer.map(|init| extra.solve_constant(init.expression)).transpose()?;
let h = extra.module.global_variables.fetch_or_append(
GlobalVariable {
name: id.clone(),
class,
binding: binding.clone(),
ty: d.ty,
init,
interpolation,
storage_access: StorageAccess::empty(), //TODO
},
);
if let Some(id) = id {
extra.lookup_global_variables.insert(id, h);
}
}
}
StorageQualifier::Const => {
for (id, initializer) in d.ids_initializers {
if let Some(init) = initializer {
let constant = extra.solve_constant(init.expression)?;
let inner = extra.module.constants[constant].inner.clone();
let h = extra.module.constants.fetch_or_append(
Constant {
name: id.clone(),
specialization: None, // TODO
inner
},
);
if let Some(id) = id {
extra.lookup_constants.insert(id.clone(), h);
let expr = extra.context.expressions.append(Expression::Constant(h));
extra.context.lookup_constant_exps.insert(id, expr);
}
} else {
return Err(ErrorKind::SemanticError("Constants must have an initalizer".into()))
}
}
}
}
}

View File

@@ -231,3 +231,59 @@ fn functions() {
)
.unwrap();
}
#[test]
fn constants() {
use crate::{Constant, ConstantInner, ScalarValue};
let program = parse_program(
r#"
# version 450
const float a = 1.0;
float global = a;
const flat float b = a;
"#,
ShaderStage::Vertex,
)
.unwrap();
let mut constants = program.module.constants.iter();
assert_eq!(
constants.next().unwrap().1,
&Constant {
name: None,
specialization: None,
inner: ConstantInner::Scalar {
width: 4,
value: ScalarValue::Float(1.0)
}
}
);
assert_eq!(
constants.next().unwrap().1,
&Constant {
name: Some(String::from("a")),
specialization: None,
inner: ConstantInner::Scalar {
width: 4,
value: ScalarValue::Float(1.0)
}
}
);
assert_eq!(
constants.next().unwrap().1,
&Constant {
name: Some(String::from("b")),
specialization: None,
inner: ConstantInner::Scalar {
width: 4,
value: ScalarValue::Float(1.0)
}
}
);
assert!(constants.next().is_none());
}

View File

@@ -15,6 +15,9 @@ impl Program {
if let Some(global_var) = self.context.lookup_global_var_exps.get(name) {
return Ok(Some(*global_var));
}
if let Some(constant) = self.context.lookup_constant_exps.get(name) {
return Ok(Some(*constant));
}
match name {
"gl_Position" => {
#[cfg(feature = "glsl-validate")]

View File

@@ -401,7 +401,7 @@ pub struct Constant {
}
/// A literal scalar value, used in constants.
#[derive(Debug, PartialEq)]
#[derive(Debug, PartialEq, Clone)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
pub enum ScalarValue {
@@ -412,7 +412,7 @@ pub enum ScalarValue {
}
/// Additional information, dependendent on the kind of constant.
#[derive(Debug, PartialEq)]
#[derive(Debug, PartialEq, Clone)]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
pub enum ConstantInner {

157
src/proc/constants.rs Normal file
View File

@@ -0,0 +1,157 @@
use crate::{
arena::{Arena, Handle},
ArraySize, Constant, ConstantInner, Expression, ScalarValue, Type,
};
#[derive(Debug)]
pub struct ConstantSolver<'a> {
pub types: &'a Arena<Type>,
pub expressions: &'a Arena<Expression>,
pub constants: &'a mut Arena<Constant>,
}
#[derive(Clone, Debug, PartialEq, thiserror::Error)]
pub enum ConstantSolvingError {
#[error("Constants cannot access function arguments")]
FunctionArg,
#[error("Constants cannot access global variables")]
GlobalVariable,
#[error("Constants cannot access local variables")]
LocalVariable,
#[error("Cannot get the array length of a non array type")]
InvalidArrayLengthArg,
#[error("Constants cannot get the array length of a dynamically sized array")]
ArrayLengthDynamic,
#[error("Constants cannot call functions")]
Call,
#[error("Constants don't support relational functions")]
Relational,
#[error("Constants don't support derivative functions")]
Derivative,
#[error("Constants don't support select expressions")]
Select,
#[error("Constants don't support load expressions")]
Load,
#[error("Constants don't support image expressions")]
ImageExpression,
#[error("Cannot access the type")]
InvalidAccessBase,
#[error("Cannot access at the index")]
InvalidAccessIndex,
#[error("Cannot access with index of type")]
InvalidAccessIndexTy,
}
impl<'a> ConstantSolver<'a> {
pub fn solve(
&mut self,
expr: Handle<Expression>,
) -> Result<Handle<Constant>, ConstantSolvingError> {
match self.expressions[expr] {
Expression::Constant(constant) => Ok(constant),
Expression::AccessIndex { base, index } => self.access(base, index as usize),
Expression::Access { base, index } => {
let index = self.solve(index)?;
self.access(base, self.constant_index(index)?)
}
Expression::Compose { ty, ref components } => {
let components = components
.iter()
.map(|c| self.solve(*c))
.collect::<Result<_, _>>()?;
Ok(self.constants.fetch_or_append(Constant {
name: None,
specialization: None,
inner: ConstantInner::Composite { ty, components },
}))
}
Expression::Unary { .. } => todo!(),
Expression::Binary { .. } => todo!(),
Expression::Math { .. } => todo!(),
Expression::As { .. } => todo!(),
Expression::ArrayLength(expr) => {
let array = self.solve(expr)?;
match self.constants[array].inner {
crate::ConstantInner::Scalar { .. } => {
Err(ConstantSolvingError::InvalidArrayLengthArg)
}
crate::ConstantInner::Composite { ty, .. } => match self.types[ty].inner {
crate::TypeInner::Array { size, .. } => match size {
crate::ArraySize::Constant(constant) => Ok(constant),
crate::ArraySize::Dynamic => {
Err(ConstantSolvingError::ArrayLengthDynamic)
}
},
_ => Err(ConstantSolvingError::InvalidArrayLengthArg),
},
}
}
Expression::Load { .. } => Err(ConstantSolvingError::Load),
Expression::Select { .. } => Err(ConstantSolvingError::Select),
Expression::LocalVariable(_) => Err(ConstantSolvingError::LocalVariable),
Expression::Derivative { .. } => Err(ConstantSolvingError::Derivative),
Expression::Relational { .. } => Err(ConstantSolvingError::Relational),
Expression::Call { .. } => Err(ConstantSolvingError::Call),
Expression::FunctionArgument(_) => Err(ConstantSolvingError::FunctionArg),
Expression::GlobalVariable(_) => Err(ConstantSolvingError::GlobalVariable),
Expression::ImageSample { .. } => Err(ConstantSolvingError::ImageExpression),
Expression::ImageLoad { .. } => Err(ConstantSolvingError::ImageExpression),
}
}
fn access(
&mut self,
base: Handle<Expression>,
index: usize,
) -> Result<Handle<Constant>, ConstantSolvingError> {
let base = self.solve(base)?;
match self.constants[base].inner {
crate::ConstantInner::Scalar { .. } => Err(ConstantSolvingError::InvalidAccessBase),
crate::ConstantInner::Composite { ty, ref components } => match self.types[ty].inner {
crate::TypeInner::Vector { size, .. } => {
if size as usize <= index {
Err(ConstantSolvingError::InvalidAccessIndex)
} else {
Ok(components[index])
}
}
crate::TypeInner::Matrix { .. } => todo!(),
crate::TypeInner::Array { size, .. } => match size {
ArraySize::Constant(constant) => {
let size = self.constant_index(constant)?;
if size <= index {
Err(ConstantSolvingError::InvalidAccessIndex)
} else {
Ok(components[index])
}
}
ArraySize::Dynamic => Err(ConstantSolvingError::ArrayLengthDynamic),
},
crate::TypeInner::Struct { ref members, .. } => {
if members.len() <= index {
Err(ConstantSolvingError::InvalidAccessIndex)
} else {
Ok(components[index])
}
}
_ => Err(ConstantSolvingError::InvalidAccessBase),
},
}
}
fn constant_index(&self, constant: Handle<Constant>) -> Result<usize, ConstantSolvingError> {
match self.constants[constant].inner {
ConstantInner::Scalar {
value: ScalarValue::Uint(index),
..
} => Ok(index as usize),
_ => Err(ConstantSolvingError::InvalidAccessIndexTy),
}
}
}

View File

@@ -2,6 +2,7 @@
#[cfg(feature = "petgraph")]
mod call_graph;
mod constants;
mod interface;
mod layouter;
mod namer;
@@ -11,6 +12,7 @@ mod validator;
#[cfg(feature = "petgraph")]
pub use call_graph::{CallGraph, CallGraphBuilder};
pub use constants::{ConstantSolver, ConstantSolvingError};
pub use interface::{Interface, Visitor};
pub use layouter::{Alignment, Layouter};
pub use namer::{EntryPointIndex, NameKey, Namer};