mirror of
https://github.com/gfx-rs/wgpu.git
synced 2026-04-22 03:02:01 -04:00
Added support for initalizers and constants
This commit is contained in:
committed by
Dzmitry Malyshau
parent
e05baa2889
commit
68bbf00249
@@ -1,7 +1,7 @@
|
||||
use super::error::ErrorKind;
|
||||
use crate::{
|
||||
proc::{ResolveContext, Typifier},
|
||||
Arena, BinaryOperator, Binding, Expression, FastHashMap, Function, FunctionArgument,
|
||||
proc::{ConstantSolver, ResolveContext, Typifier},
|
||||
Arena, BinaryOperator, Binding, Constant, Expression, FastHashMap, Function, FunctionArgument,
|
||||
GlobalVariable, Handle, Interpolation, LocalVariable, Module, ShaderStage, Statement,
|
||||
StorageClass, Type,
|
||||
};
|
||||
@@ -15,6 +15,7 @@ pub struct Program {
|
||||
pub lookup_function: FastHashMap<String, Handle<Function>>,
|
||||
pub lookup_type: FastHashMap<String, Handle<Type>>,
|
||||
pub lookup_global_variables: FastHashMap<String, Handle<GlobalVariable>>,
|
||||
pub lookup_constants: FastHashMap<String, Handle<Constant>>,
|
||||
pub context: Context,
|
||||
pub module: Module,
|
||||
}
|
||||
@@ -29,12 +30,14 @@ impl Program {
|
||||
lookup_function: FastHashMap::default(),
|
||||
lookup_type: FastHashMap::default(),
|
||||
lookup_global_variables: FastHashMap::default(),
|
||||
lookup_constants: FastHashMap::default(),
|
||||
context: Context {
|
||||
expressions: Arena::<Expression>::new(),
|
||||
local_variables: Arena::<LocalVariable>::new(),
|
||||
arguments: Vec::new(),
|
||||
scopes: vec![FastHashMap::default()],
|
||||
lookup_global_var_exps: FastHashMap::default(),
|
||||
lookup_constant_exps: FastHashMap::default(),
|
||||
typifier: Typifier::new(),
|
||||
},
|
||||
module: Module::generate_empty(),
|
||||
@@ -76,6 +79,21 @@ impl Program {
|
||||
Ok(()) => Ok(self.context.typifier.get(handle, &self.module.types)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn solve_constant(
|
||||
&mut self,
|
||||
root: Handle<Expression>,
|
||||
) -> Result<Handle<Constant>, ErrorKind> {
|
||||
let mut solver = ConstantSolver {
|
||||
types: &self.module.types,
|
||||
expressions: &self.context.expressions,
|
||||
constants: &mut self.module.constants,
|
||||
};
|
||||
|
||||
solver
|
||||
.solve(root)
|
||||
.map_err(|_| ErrorKind::SemanticError("Can't solve constant".into()))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
@@ -91,6 +109,7 @@ pub struct Context {
|
||||
//TODO: Find less allocation heavy representation
|
||||
pub scopes: Vec<FastHashMap<String, Handle<Expression>>>,
|
||||
pub lookup_global_var_exps: FastHashMap<String, Handle<Expression>>,
|
||||
pub lookup_constant_exps: FastHashMap<String, Handle<Expression>>,
|
||||
pub typifier: Typifier,
|
||||
}
|
||||
|
||||
@@ -154,7 +173,7 @@ impl ExpressionRule {
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum TypeQualifier {
|
||||
StorageClass(StorageClass),
|
||||
StorageQualifier(StorageQualifier),
|
||||
Binding(Binding),
|
||||
Interpolation(Interpolation),
|
||||
}
|
||||
@@ -177,3 +196,9 @@ pub struct FunctionCall {
|
||||
pub kind: FunctionCallKind,
|
||||
pub args: Vec<ExpressionRule>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum StorageQualifier {
|
||||
StorageClass(StorageClass),
|
||||
Const,
|
||||
}
|
||||
|
||||
@@ -179,6 +179,16 @@ impl Program {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (handle, constant) in self.module.constants.iter() {
|
||||
if let Some(name) = constant.name.as_ref() {
|
||||
let expr = self
|
||||
.context
|
||||
.expressions
|
||||
.append(Expression::Constant(handle));
|
||||
self.context.lookup_constant_exps.insert(name.clone(), expr);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn function_definition(&mut self, mut f: Function, mut block: Block) -> Function {
|
||||
|
||||
@@ -182,6 +182,7 @@ impl<'a> Lexer<'a> {
|
||||
"for" => Some(Token::For(meta)),
|
||||
// types
|
||||
"void" => Some(Token::Void(meta)),
|
||||
"const" => Some(Token::Const(meta)),
|
||||
word => {
|
||||
let token = match parse_type(word) {
|
||||
Some(t) => Token::TypeName((meta, t)),
|
||||
|
||||
@@ -114,7 +114,7 @@ pomelo! {
|
||||
%type layout_qualifier_id (String, u32);
|
||||
%type type_qualifier Vec<TypeQualifier>;
|
||||
%type single_type_qualifier TypeQualifier;
|
||||
%type storage_qualifier StorageClass;
|
||||
%type storage_qualifier StorageQualifier;
|
||||
%type interpolation_qualifier Interpolation;
|
||||
%type Interpolation Interpolation;
|
||||
|
||||
@@ -627,7 +627,7 @@ pomelo! {
|
||||
}
|
||||
|
||||
single_type_qualifier ::= storage_qualifier(s) {
|
||||
TypeQualifier::StorageClass(s)
|
||||
TypeQualifier::StorageQualifier(s)
|
||||
}
|
||||
single_type_qualifier ::= layout_qualifier(l) {
|
||||
TypeQualifier::Binding(l)
|
||||
@@ -639,19 +639,21 @@ pomelo! {
|
||||
// single_type_qualifier ::= invariant_qualifier;
|
||||
// single_type_qualifier ::= precise_qualifier;
|
||||
|
||||
// storage_qualifier ::= Const
|
||||
storage_qualifier ::= Const {
|
||||
StorageQualifier::Const
|
||||
}
|
||||
// storage_qualifier ::= InOut;
|
||||
storage_qualifier ::= In {
|
||||
StorageClass::Input
|
||||
StorageQualifier::StorageClass(StorageClass::Input)
|
||||
}
|
||||
storage_qualifier ::= Out {
|
||||
StorageClass::Output
|
||||
StorageQualifier::StorageClass(StorageClass::Output)
|
||||
}
|
||||
// storage_qualifier ::= Centroid;
|
||||
// storage_qualifier ::= Patch;
|
||||
// storage_qualifier ::= Sample;
|
||||
storage_qualifier ::= Uniform {
|
||||
StorageClass::Uniform
|
||||
StorageQualifier::StorageClass(StorageClass::Uniform)
|
||||
}
|
||||
//TODO: other storage qualifiers
|
||||
|
||||
@@ -1062,32 +1064,63 @@ pomelo! {
|
||||
}
|
||||
external_declaration ::= declaration(d) {
|
||||
if let Some(d) = d {
|
||||
let class = d.type_qualifiers.iter().find_map(|tq| {
|
||||
if let TypeQualifier::StorageClass(sc) = tq { Some(*sc) } else { None }
|
||||
}).unwrap_or(StorageClass::Private);
|
||||
// TODO: handle multiple storage qualifiers
|
||||
let storage = d.type_qualifiers.iter().find_map(|tq| {
|
||||
if let TypeQualifier::StorageQualifier(sc) = tq { Some(*sc) } else { None }
|
||||
}).unwrap_or(StorageQualifier::StorageClass(StorageClass::Private));
|
||||
|
||||
let binding = d.type_qualifiers.iter().find_map(|tq| {
|
||||
if let TypeQualifier::Binding(b) = tq { Some(b.clone()) } else { None }
|
||||
});
|
||||
match storage {
|
||||
StorageQualifier::StorageClass(class) => {
|
||||
// TODO: Check that the storage qualifiers allow for the bindings
|
||||
let binding = d.type_qualifiers.iter().find_map(|tq| {
|
||||
if let TypeQualifier::Binding(b) = tq { Some(b.clone()) } else { None }
|
||||
});
|
||||
|
||||
let interpolation = d.type_qualifiers.iter().find_map(|tq| {
|
||||
if let TypeQualifier::Interpolation(i) = tq { Some(*i) } else { None }
|
||||
});
|
||||
let interpolation = d.type_qualifiers.iter().find_map(|tq| {
|
||||
if let TypeQualifier::Interpolation(i) = tq { Some(*i) } else { None }
|
||||
});
|
||||
|
||||
for (id, initializer) in d.ids_initializers {
|
||||
let h = extra.module.global_variables.fetch_or_append(
|
||||
GlobalVariable {
|
||||
name: id.clone(),
|
||||
class,
|
||||
binding: binding.clone(),
|
||||
ty: d.ty,
|
||||
init: None,
|
||||
interpolation,
|
||||
storage_access: StorageAccess::empty(), //TODO
|
||||
},
|
||||
);
|
||||
if let Some(id) = id {
|
||||
extra.lookup_global_variables.insert(id, h);
|
||||
for (id, initializer) in d.ids_initializers {
|
||||
let init = initializer.map(|init| extra.solve_constant(init.expression)).transpose()?;
|
||||
|
||||
let h = extra.module.global_variables.fetch_or_append(
|
||||
GlobalVariable {
|
||||
name: id.clone(),
|
||||
class,
|
||||
binding: binding.clone(),
|
||||
ty: d.ty,
|
||||
init,
|
||||
interpolation,
|
||||
storage_access: StorageAccess::empty(), //TODO
|
||||
},
|
||||
);
|
||||
if let Some(id) = id {
|
||||
extra.lookup_global_variables.insert(id, h);
|
||||
}
|
||||
}
|
||||
}
|
||||
StorageQualifier::Const => {
|
||||
for (id, initializer) in d.ids_initializers {
|
||||
if let Some(init) = initializer {
|
||||
let constant = extra.solve_constant(init.expression)?;
|
||||
let inner = extra.module.constants[constant].inner.clone();
|
||||
|
||||
let h = extra.module.constants.fetch_or_append(
|
||||
Constant {
|
||||
name: id.clone(),
|
||||
specialization: None, // TODO
|
||||
inner
|
||||
},
|
||||
);
|
||||
if let Some(id) = id {
|
||||
extra.lookup_constants.insert(id.clone(), h);
|
||||
let expr = extra.context.expressions.append(Expression::Constant(h));
|
||||
extra.context.lookup_constant_exps.insert(id, expr);
|
||||
}
|
||||
} else {
|
||||
return Err(ErrorKind::SemanticError("Constants must have an initalizer".into()))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -231,3 +231,59 @@ fn functions() {
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn constants() {
|
||||
use crate::{Constant, ConstantInner, ScalarValue};
|
||||
|
||||
let program = parse_program(
|
||||
r#"
|
||||
# version 450
|
||||
const float a = 1.0;
|
||||
float global = a;
|
||||
const flat float b = a;
|
||||
"#,
|
||||
ShaderStage::Vertex,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let mut constants = program.module.constants.iter();
|
||||
|
||||
assert_eq!(
|
||||
constants.next().unwrap().1,
|
||||
&Constant {
|
||||
name: None,
|
||||
specialization: None,
|
||||
inner: ConstantInner::Scalar {
|
||||
width: 4,
|
||||
value: ScalarValue::Float(1.0)
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
constants.next().unwrap().1,
|
||||
&Constant {
|
||||
name: Some(String::from("a")),
|
||||
specialization: None,
|
||||
inner: ConstantInner::Scalar {
|
||||
width: 4,
|
||||
value: ScalarValue::Float(1.0)
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
constants.next().unwrap().1,
|
||||
&Constant {
|
||||
name: Some(String::from("b")),
|
||||
specialization: None,
|
||||
inner: ConstantInner::Scalar {
|
||||
width: 4,
|
||||
value: ScalarValue::Float(1.0)
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
assert!(constants.next().is_none());
|
||||
}
|
||||
|
||||
@@ -15,6 +15,9 @@ impl Program {
|
||||
if let Some(global_var) = self.context.lookup_global_var_exps.get(name) {
|
||||
return Ok(Some(*global_var));
|
||||
}
|
||||
if let Some(constant) = self.context.lookup_constant_exps.get(name) {
|
||||
return Ok(Some(*constant));
|
||||
}
|
||||
match name {
|
||||
"gl_Position" => {
|
||||
#[cfg(feature = "glsl-validate")]
|
||||
|
||||
@@ -401,7 +401,7 @@ pub struct Constant {
|
||||
}
|
||||
|
||||
/// A literal scalar value, used in constants.
|
||||
#[derive(Debug, PartialEq)]
|
||||
#[derive(Debug, PartialEq, Clone)]
|
||||
#[cfg_attr(feature = "serialize", derive(Serialize))]
|
||||
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
|
||||
pub enum ScalarValue {
|
||||
@@ -412,7 +412,7 @@ pub enum ScalarValue {
|
||||
}
|
||||
|
||||
/// Additional information, dependendent on the kind of constant.
|
||||
#[derive(Debug, PartialEq)]
|
||||
#[derive(Debug, PartialEq, Clone)]
|
||||
#[cfg_attr(feature = "serialize", derive(Serialize))]
|
||||
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
|
||||
pub enum ConstantInner {
|
||||
|
||||
157
src/proc/constants.rs
Normal file
157
src/proc/constants.rs
Normal file
@@ -0,0 +1,157 @@
|
||||
use crate::{
|
||||
arena::{Arena, Handle},
|
||||
ArraySize, Constant, ConstantInner, Expression, ScalarValue, Type,
|
||||
};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ConstantSolver<'a> {
|
||||
pub types: &'a Arena<Type>,
|
||||
pub expressions: &'a Arena<Expression>,
|
||||
pub constants: &'a mut Arena<Constant>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, thiserror::Error)]
|
||||
pub enum ConstantSolvingError {
|
||||
#[error("Constants cannot access function arguments")]
|
||||
FunctionArg,
|
||||
#[error("Constants cannot access global variables")]
|
||||
GlobalVariable,
|
||||
#[error("Constants cannot access local variables")]
|
||||
LocalVariable,
|
||||
#[error("Cannot get the array length of a non array type")]
|
||||
InvalidArrayLengthArg,
|
||||
#[error("Constants cannot get the array length of a dynamically sized array")]
|
||||
ArrayLengthDynamic,
|
||||
#[error("Constants cannot call functions")]
|
||||
Call,
|
||||
#[error("Constants don't support relational functions")]
|
||||
Relational,
|
||||
#[error("Constants don't support derivative functions")]
|
||||
Derivative,
|
||||
#[error("Constants don't support select expressions")]
|
||||
Select,
|
||||
#[error("Constants don't support load expressions")]
|
||||
Load,
|
||||
#[error("Constants don't support image expressions")]
|
||||
ImageExpression,
|
||||
#[error("Cannot access the type")]
|
||||
InvalidAccessBase,
|
||||
#[error("Cannot access at the index")]
|
||||
InvalidAccessIndex,
|
||||
#[error("Cannot access with index of type")]
|
||||
InvalidAccessIndexTy,
|
||||
}
|
||||
|
||||
impl<'a> ConstantSolver<'a> {
|
||||
pub fn solve(
|
||||
&mut self,
|
||||
expr: Handle<Expression>,
|
||||
) -> Result<Handle<Constant>, ConstantSolvingError> {
|
||||
match self.expressions[expr] {
|
||||
Expression::Constant(constant) => Ok(constant),
|
||||
Expression::AccessIndex { base, index } => self.access(base, index as usize),
|
||||
Expression::Access { base, index } => {
|
||||
let index = self.solve(index)?;
|
||||
|
||||
self.access(base, self.constant_index(index)?)
|
||||
}
|
||||
Expression::Compose { ty, ref components } => {
|
||||
let components = components
|
||||
.iter()
|
||||
.map(|c| self.solve(*c))
|
||||
.collect::<Result<_, _>>()?;
|
||||
|
||||
Ok(self.constants.fetch_or_append(Constant {
|
||||
name: None,
|
||||
specialization: None,
|
||||
inner: ConstantInner::Composite { ty, components },
|
||||
}))
|
||||
}
|
||||
Expression::Unary { .. } => todo!(),
|
||||
Expression::Binary { .. } => todo!(),
|
||||
Expression::Math { .. } => todo!(),
|
||||
Expression::As { .. } => todo!(),
|
||||
Expression::ArrayLength(expr) => {
|
||||
let array = self.solve(expr)?;
|
||||
|
||||
match self.constants[array].inner {
|
||||
crate::ConstantInner::Scalar { .. } => {
|
||||
Err(ConstantSolvingError::InvalidArrayLengthArg)
|
||||
}
|
||||
crate::ConstantInner::Composite { ty, .. } => match self.types[ty].inner {
|
||||
crate::TypeInner::Array { size, .. } => match size {
|
||||
crate::ArraySize::Constant(constant) => Ok(constant),
|
||||
crate::ArraySize::Dynamic => {
|
||||
Err(ConstantSolvingError::ArrayLengthDynamic)
|
||||
}
|
||||
},
|
||||
_ => Err(ConstantSolvingError::InvalidArrayLengthArg),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
Expression::Load { .. } => Err(ConstantSolvingError::Load),
|
||||
Expression::Select { .. } => Err(ConstantSolvingError::Select),
|
||||
Expression::LocalVariable(_) => Err(ConstantSolvingError::LocalVariable),
|
||||
Expression::Derivative { .. } => Err(ConstantSolvingError::Derivative),
|
||||
Expression::Relational { .. } => Err(ConstantSolvingError::Relational),
|
||||
Expression::Call { .. } => Err(ConstantSolvingError::Call),
|
||||
Expression::FunctionArgument(_) => Err(ConstantSolvingError::FunctionArg),
|
||||
Expression::GlobalVariable(_) => Err(ConstantSolvingError::GlobalVariable),
|
||||
Expression::ImageSample { .. } => Err(ConstantSolvingError::ImageExpression),
|
||||
Expression::ImageLoad { .. } => Err(ConstantSolvingError::ImageExpression),
|
||||
}
|
||||
}
|
||||
|
||||
fn access(
|
||||
&mut self,
|
||||
base: Handle<Expression>,
|
||||
index: usize,
|
||||
) -> Result<Handle<Constant>, ConstantSolvingError> {
|
||||
let base = self.solve(base)?;
|
||||
|
||||
match self.constants[base].inner {
|
||||
crate::ConstantInner::Scalar { .. } => Err(ConstantSolvingError::InvalidAccessBase),
|
||||
crate::ConstantInner::Composite { ty, ref components } => match self.types[ty].inner {
|
||||
crate::TypeInner::Vector { size, .. } => {
|
||||
if size as usize <= index {
|
||||
Err(ConstantSolvingError::InvalidAccessIndex)
|
||||
} else {
|
||||
Ok(components[index])
|
||||
}
|
||||
}
|
||||
crate::TypeInner::Matrix { .. } => todo!(),
|
||||
crate::TypeInner::Array { size, .. } => match size {
|
||||
ArraySize::Constant(constant) => {
|
||||
let size = self.constant_index(constant)?;
|
||||
|
||||
if size <= index {
|
||||
Err(ConstantSolvingError::InvalidAccessIndex)
|
||||
} else {
|
||||
Ok(components[index])
|
||||
}
|
||||
}
|
||||
ArraySize::Dynamic => Err(ConstantSolvingError::ArrayLengthDynamic),
|
||||
},
|
||||
crate::TypeInner::Struct { ref members, .. } => {
|
||||
if members.len() <= index {
|
||||
Err(ConstantSolvingError::InvalidAccessIndex)
|
||||
} else {
|
||||
Ok(components[index])
|
||||
}
|
||||
}
|
||||
_ => Err(ConstantSolvingError::InvalidAccessBase),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn constant_index(&self, constant: Handle<Constant>) -> Result<usize, ConstantSolvingError> {
|
||||
match self.constants[constant].inner {
|
||||
ConstantInner::Scalar {
|
||||
value: ScalarValue::Uint(index),
|
||||
..
|
||||
} => Ok(index as usize),
|
||||
_ => Err(ConstantSolvingError::InvalidAccessIndexTy),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
#[cfg(feature = "petgraph")]
|
||||
mod call_graph;
|
||||
mod constants;
|
||||
mod interface;
|
||||
mod layouter;
|
||||
mod namer;
|
||||
@@ -11,6 +12,7 @@ mod validator;
|
||||
|
||||
#[cfg(feature = "petgraph")]
|
||||
pub use call_graph::{CallGraph, CallGraphBuilder};
|
||||
pub use constants::{ConstantSolver, ConstantSolvingError};
|
||||
pub use interface::{Interface, Visitor};
|
||||
pub use layouter::{Alignment, Layouter};
|
||||
pub use namer::{EntryPointIndex, NameKey, Namer};
|
||||
|
||||
Reference in New Issue
Block a user