mirror of
https://github.com/gfx-rs/wgpu.git
synced 2026-04-22 03:02:01 -04:00
implement override-expression evaluation for initializers of override declarations
This commit is contained in:
@@ -122,6 +122,7 @@ impl<T> Handle<T> {
|
||||
serde(transparent)
|
||||
)]
|
||||
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
|
||||
#[cfg_attr(test, derive(PartialEq))]
|
||||
pub struct Range<T> {
|
||||
inner: ops::Range<u32>,
|
||||
#[cfg_attr(any(feature = "serialize", feature = "deserialize"), serde(skip))]
|
||||
@@ -140,6 +141,7 @@ impl<T> Range<T> {
|
||||
|
||||
// NOTE: Keep this diagnostic in sync with that of [`BadHandle`].
|
||||
#[derive(Clone, Debug, thiserror::Error)]
|
||||
#[cfg_attr(test, derive(PartialEq))]
|
||||
#[error("Handle range {range:?} of {kind} is either not present, or inaccessible yet")]
|
||||
pub struct BadRangeError {
|
||||
// This error is used for many `Handle` types, but there's no point in making this generic, so
|
||||
|
||||
@@ -256,7 +256,7 @@ pub enum Error {
|
||||
#[error("{0}")]
|
||||
Custom(String),
|
||||
#[error(transparent)]
|
||||
PipelineConstant(#[from] back::pipeline_constants::PipelineConstantError),
|
||||
PipelineConstant(#[from] Box<back::pipeline_constants::PipelineConstantError>),
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
|
||||
@@ -169,9 +169,14 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
|
||||
module_info: &valid::ModuleInfo,
|
||||
pipeline_options: &PipelineOptions,
|
||||
) -> Result<super::ReflectionInfo, Error> {
|
||||
let module =
|
||||
back::pipeline_constants::process_overrides(module, &pipeline_options.constants)?;
|
||||
let (module, module_info) = back::pipeline_constants::process_overrides(
|
||||
module,
|
||||
module_info,
|
||||
&pipeline_options.constants,
|
||||
)
|
||||
.map_err(Box::new)?;
|
||||
let module = module.as_ref();
|
||||
let module_info = module_info.as_ref();
|
||||
|
||||
self.reset(module);
|
||||
|
||||
|
||||
@@ -144,7 +144,7 @@ pub enum Error {
|
||||
#[error("ray tracing is not supported prior to MSL 2.3")]
|
||||
UnsupportedRayTracing,
|
||||
#[error(transparent)]
|
||||
PipelineConstant(#[from] crate::back::pipeline_constants::PipelineConstantError),
|
||||
PipelineConstant(#[from] Box<crate::back::pipeline_constants::PipelineConstantError>),
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, thiserror::Error)]
|
||||
|
||||
@@ -3223,9 +3223,11 @@ impl<W: Write> Writer<W> {
|
||||
options: &Options,
|
||||
pipeline_options: &PipelineOptions,
|
||||
) -> Result<TranslationInfo, Error> {
|
||||
let module =
|
||||
back::pipeline_constants::process_overrides(module, &pipeline_options.constants)?;
|
||||
let (module, info) =
|
||||
back::pipeline_constants::process_overrides(module, info, &pipeline_options.constants)
|
||||
.map_err(Box::new)?;
|
||||
let module = module.as_ref();
|
||||
let info = info.as_ref();
|
||||
|
||||
self.names.clear();
|
||||
self.namer.reset(
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
use super::PipelineConstants;
|
||||
use crate::{Constant, Expression, Literal, Module, Scalar, Span, TypeInner};
|
||||
use std::borrow::Cow;
|
||||
use crate::{
|
||||
proc::{ConstantEvaluator, ConstantEvaluatorError},
|
||||
valid::{Capabilities, ModuleInfo, ValidationError, ValidationFlags, Validator},
|
||||
Constant, Expression, Handle, Literal, Module, Override, Scalar, Span, TypeInner, WithSpan,
|
||||
};
|
||||
use std::{borrow::Cow, collections::HashSet};
|
||||
use thiserror::Error;
|
||||
|
||||
#[derive(Error, Debug, Clone)]
|
||||
@@ -12,48 +16,317 @@ pub enum PipelineConstantError {
|
||||
SrcNeedsToBeFinite,
|
||||
#[error("Source f64 value doesn't fit in destination")]
|
||||
DstRangeTooSmall,
|
||||
#[error(transparent)]
|
||||
ConstantEvaluatorError(#[from] ConstantEvaluatorError),
|
||||
#[error(transparent)]
|
||||
ValidationError(#[from] WithSpan<ValidationError>),
|
||||
}
|
||||
|
||||
pub(super) fn process_overrides<'a>(
|
||||
module: &'a Module,
|
||||
module_info: &'a ModuleInfo,
|
||||
pipeline_constants: &PipelineConstants,
|
||||
) -> Result<Cow<'a, Module>, PipelineConstantError> {
|
||||
) -> Result<(Cow<'a, Module>, Cow<'a, ModuleInfo>), PipelineConstantError> {
|
||||
if module.overrides.is_empty() {
|
||||
return Ok(Cow::Borrowed(module));
|
||||
return Ok((Cow::Borrowed(module), Cow::Borrowed(module_info)));
|
||||
}
|
||||
|
||||
let mut module = module.clone();
|
||||
let mut override_map = Vec::with_capacity(module.overrides.len());
|
||||
let mut adjusted_const_expressions = Vec::with_capacity(module.const_expressions.len());
|
||||
let mut adjusted_constant_initializers = HashSet::with_capacity(module.constants.len());
|
||||
|
||||
for (_handle, override_, span) in module.overrides.drain() {
|
||||
let key = if let Some(id) = override_.id {
|
||||
Cow::Owned(id.to_string())
|
||||
} else if let Some(ref name) = override_.name {
|
||||
Cow::Borrowed(name)
|
||||
} else {
|
||||
unreachable!();
|
||||
let mut global_expression_kind_tracker = crate::proc::ExpressionConstnessTracker::new();
|
||||
|
||||
let mut override_iter = module.overrides.drain();
|
||||
|
||||
for (old_h, expr, span) in module.const_expressions.drain() {
|
||||
let mut expr = match expr {
|
||||
Expression::Override(h) => {
|
||||
let c_h = if let Some(new_h) = override_map.get(h.index()) {
|
||||
*new_h
|
||||
} else {
|
||||
let mut new_h = None;
|
||||
for entry in override_iter.by_ref() {
|
||||
let stop = entry.0 == h;
|
||||
new_h = Some(process_override(
|
||||
entry,
|
||||
pipeline_constants,
|
||||
&mut module,
|
||||
&mut override_map,
|
||||
&adjusted_const_expressions,
|
||||
&mut adjusted_constant_initializers,
|
||||
&mut global_expression_kind_tracker,
|
||||
)?);
|
||||
if stop {
|
||||
break;
|
||||
}
|
||||
}
|
||||
new_h.unwrap()
|
||||
};
|
||||
Expression::Constant(c_h)
|
||||
}
|
||||
Expression::Constant(c_h) => {
|
||||
adjusted_constant_initializers.insert(c_h);
|
||||
module.constants[c_h].init = adjusted_const_expressions[c_h.index()];
|
||||
expr
|
||||
}
|
||||
expr => expr,
|
||||
};
|
||||
let init = if let Some(value) = pipeline_constants.get::<str>(&key) {
|
||||
let literal = match module.types[override_.ty].inner {
|
||||
TypeInner::Scalar(scalar) => map_value_to_literal(*value, scalar)?,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
module
|
||||
.const_expressions
|
||||
.append(Expression::Literal(literal), Span::UNDEFINED)
|
||||
} else if let Some(init) = override_.init {
|
||||
init
|
||||
} else {
|
||||
return Err(PipelineConstantError::MissingValue(key.to_string()));
|
||||
};
|
||||
let constant = Constant {
|
||||
name: override_.name,
|
||||
ty: override_.ty,
|
||||
init,
|
||||
};
|
||||
module.constants.append(constant, span);
|
||||
let mut evaluator = ConstantEvaluator::for_wgsl_module(
|
||||
&mut module,
|
||||
&mut global_expression_kind_tracker,
|
||||
false,
|
||||
);
|
||||
adjust_expr(&adjusted_const_expressions, &mut expr);
|
||||
let h = evaluator.try_eval_and_append(expr, span)?;
|
||||
debug_assert_eq!(old_h.index(), adjusted_const_expressions.len());
|
||||
adjusted_const_expressions.push(h);
|
||||
}
|
||||
|
||||
Ok(Cow::Owned(module))
|
||||
for entry in override_iter {
|
||||
process_override(
|
||||
entry,
|
||||
pipeline_constants,
|
||||
&mut module,
|
||||
&mut override_map,
|
||||
&adjusted_const_expressions,
|
||||
&mut adjusted_constant_initializers,
|
||||
&mut global_expression_kind_tracker,
|
||||
)?;
|
||||
}
|
||||
|
||||
for (_, c) in module
|
||||
.constants
|
||||
.iter_mut()
|
||||
.filter(|&(c_h, _)| !adjusted_constant_initializers.contains(&c_h))
|
||||
{
|
||||
c.init = adjusted_const_expressions[c.init.index()];
|
||||
}
|
||||
|
||||
for (_, v) in module.global_variables.iter_mut() {
|
||||
if let Some(ref mut init) = v.init {
|
||||
*init = adjusted_const_expressions[init.index()];
|
||||
}
|
||||
}
|
||||
|
||||
let mut validator = Validator::new(ValidationFlags::all(), Capabilities::all());
|
||||
let module_info = validator.validate(&module)?;
|
||||
|
||||
Ok((Cow::Owned(module), Cow::Owned(module_info)))
|
||||
}
|
||||
|
||||
fn process_override(
|
||||
(old_h, override_, span): (Handle<Override>, Override, Span),
|
||||
pipeline_constants: &PipelineConstants,
|
||||
module: &mut Module,
|
||||
override_map: &mut Vec<Handle<Constant>>,
|
||||
adjusted_const_expressions: &[Handle<Expression>],
|
||||
adjusted_constant_initializers: &mut HashSet<Handle<Constant>>,
|
||||
global_expression_kind_tracker: &mut crate::proc::ExpressionConstnessTracker,
|
||||
) -> Result<Handle<Constant>, PipelineConstantError> {
|
||||
let key = if let Some(id) = override_.id {
|
||||
Cow::Owned(id.to_string())
|
||||
} else if let Some(ref name) = override_.name {
|
||||
Cow::Borrowed(name)
|
||||
} else {
|
||||
unreachable!();
|
||||
};
|
||||
let init = if let Some(value) = pipeline_constants.get::<str>(&key) {
|
||||
let literal = match module.types[override_.ty].inner {
|
||||
TypeInner::Scalar(scalar) => map_value_to_literal(*value, scalar)?,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
let expr = module
|
||||
.const_expressions
|
||||
.append(Expression::Literal(literal), Span::UNDEFINED);
|
||||
global_expression_kind_tracker.insert(expr, crate::proc::ExpressionKind::Const);
|
||||
expr
|
||||
} else if let Some(init) = override_.init {
|
||||
adjusted_const_expressions[init.index()]
|
||||
} else {
|
||||
return Err(PipelineConstantError::MissingValue(key.to_string()));
|
||||
};
|
||||
let constant = Constant {
|
||||
name: override_.name,
|
||||
ty: override_.ty,
|
||||
init,
|
||||
};
|
||||
let h = module.constants.append(constant, span);
|
||||
debug_assert_eq!(old_h.index(), override_map.len());
|
||||
override_map.push(h);
|
||||
adjusted_constant_initializers.insert(h);
|
||||
Ok(h)
|
||||
}
|
||||
|
||||
fn adjust_expr(new_pos: &[Handle<Expression>], expr: &mut Expression) {
|
||||
let adjust = |expr: &mut Handle<Expression>| {
|
||||
*expr = new_pos[expr.index()];
|
||||
};
|
||||
match *expr {
|
||||
Expression::Compose {
|
||||
ref mut components, ..
|
||||
} => {
|
||||
for c in components.iter_mut() {
|
||||
adjust(c);
|
||||
}
|
||||
}
|
||||
Expression::Access {
|
||||
ref mut base,
|
||||
ref mut index,
|
||||
} => {
|
||||
adjust(base);
|
||||
adjust(index);
|
||||
}
|
||||
Expression::AccessIndex { ref mut base, .. } => {
|
||||
adjust(base);
|
||||
}
|
||||
Expression::Splat { ref mut value, .. } => {
|
||||
adjust(value);
|
||||
}
|
||||
Expression::Swizzle { ref mut vector, .. } => {
|
||||
adjust(vector);
|
||||
}
|
||||
Expression::Load { ref mut pointer } => {
|
||||
adjust(pointer);
|
||||
}
|
||||
Expression::ImageSample {
|
||||
ref mut image,
|
||||
ref mut sampler,
|
||||
ref mut coordinate,
|
||||
ref mut array_index,
|
||||
ref mut offset,
|
||||
ref mut level,
|
||||
ref mut depth_ref,
|
||||
..
|
||||
} => {
|
||||
adjust(image);
|
||||
adjust(sampler);
|
||||
adjust(coordinate);
|
||||
if let Some(e) = array_index.as_mut() {
|
||||
adjust(e);
|
||||
}
|
||||
if let Some(e) = offset.as_mut() {
|
||||
adjust(e);
|
||||
}
|
||||
match *level {
|
||||
crate::SampleLevel::Exact(ref mut expr)
|
||||
| crate::SampleLevel::Bias(ref mut expr) => {
|
||||
adjust(expr);
|
||||
}
|
||||
crate::SampleLevel::Gradient {
|
||||
ref mut x,
|
||||
ref mut y,
|
||||
} => {
|
||||
adjust(x);
|
||||
adjust(y);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
if let Some(e) = depth_ref.as_mut() {
|
||||
adjust(e);
|
||||
}
|
||||
}
|
||||
Expression::ImageLoad {
|
||||
ref mut image,
|
||||
ref mut coordinate,
|
||||
ref mut array_index,
|
||||
ref mut sample,
|
||||
ref mut level,
|
||||
} => {
|
||||
adjust(image);
|
||||
adjust(coordinate);
|
||||
if let Some(e) = array_index.as_mut() {
|
||||
adjust(e);
|
||||
}
|
||||
if let Some(e) = sample.as_mut() {
|
||||
adjust(e);
|
||||
}
|
||||
if let Some(e) = level.as_mut() {
|
||||
adjust(e);
|
||||
}
|
||||
}
|
||||
Expression::ImageQuery {
|
||||
ref mut image,
|
||||
ref mut query,
|
||||
} => {
|
||||
adjust(image);
|
||||
match *query {
|
||||
crate::ImageQuery::Size { ref mut level } => {
|
||||
if let Some(e) = level.as_mut() {
|
||||
adjust(e);
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
Expression::Unary { ref mut expr, .. } => {
|
||||
adjust(expr);
|
||||
}
|
||||
Expression::Binary {
|
||||
ref mut left,
|
||||
ref mut right,
|
||||
..
|
||||
} => {
|
||||
adjust(left);
|
||||
adjust(right);
|
||||
}
|
||||
Expression::Select {
|
||||
ref mut condition,
|
||||
ref mut accept,
|
||||
ref mut reject,
|
||||
} => {
|
||||
adjust(condition);
|
||||
adjust(accept);
|
||||
adjust(reject);
|
||||
}
|
||||
Expression::Derivative { ref mut expr, .. } => {
|
||||
adjust(expr);
|
||||
}
|
||||
Expression::Relational {
|
||||
ref mut argument, ..
|
||||
} => {
|
||||
adjust(argument);
|
||||
}
|
||||
Expression::Math {
|
||||
ref mut arg,
|
||||
ref mut arg1,
|
||||
ref mut arg2,
|
||||
ref mut arg3,
|
||||
..
|
||||
} => {
|
||||
adjust(arg);
|
||||
if let Some(e) = arg1.as_mut() {
|
||||
adjust(e);
|
||||
}
|
||||
if let Some(e) = arg2.as_mut() {
|
||||
adjust(e);
|
||||
}
|
||||
if let Some(e) = arg3.as_mut() {
|
||||
adjust(e);
|
||||
}
|
||||
}
|
||||
Expression::As { ref mut expr, .. } => {
|
||||
adjust(expr);
|
||||
}
|
||||
Expression::ArrayLength(ref mut expr) => {
|
||||
adjust(expr);
|
||||
}
|
||||
Expression::RayQueryGetIntersection { ref mut query, .. } => {
|
||||
adjust(query);
|
||||
}
|
||||
Expression::Literal(_)
|
||||
| Expression::FunctionArgument(_)
|
||||
| Expression::GlobalVariable(_)
|
||||
| Expression::LocalVariable(_)
|
||||
| Expression::CallResult(_)
|
||||
| Expression::RayQueryProceedResult
|
||||
| Expression::Constant(_)
|
||||
| Expression::Override(_)
|
||||
| Expression::ZeroValue(_)
|
||||
| Expression::AtomicResult { .. }
|
||||
| Expression::WorkGroupUniformLoadResult { .. } => {}
|
||||
}
|
||||
}
|
||||
|
||||
fn map_value_to_literal(value: f64, scalar: Scalar) -> Result<Literal, PipelineConstantError> {
|
||||
|
||||
@@ -71,7 +71,7 @@ pub enum Error {
|
||||
#[error("module is not validated properly: {0}")]
|
||||
Validation(&'static str),
|
||||
#[error(transparent)]
|
||||
PipelineConstant(#[from] crate::back::pipeline_constants::PipelineConstantError),
|
||||
PipelineConstant(#[from] Box<crate::back::pipeline_constants::PipelineConstantError>),
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
@@ -529,6 +529,42 @@ struct FunctionArgument {
|
||||
handle_id: Word,
|
||||
}
|
||||
|
||||
/// Tracks the expressions for which the backend emits the following instructions:
|
||||
/// - OpConstantTrue
|
||||
/// - OpConstantFalse
|
||||
/// - OpConstant
|
||||
/// - OpConstantComposite
|
||||
/// - OpConstantNull
|
||||
struct ExpressionConstnessTracker {
|
||||
inner: bit_set::BitSet,
|
||||
}
|
||||
|
||||
impl ExpressionConstnessTracker {
|
||||
fn from_arena(arena: &crate::Arena<crate::Expression>) -> Self {
|
||||
let mut inner = bit_set::BitSet::new();
|
||||
for (handle, expr) in arena.iter() {
|
||||
let insert = match *expr {
|
||||
crate::Expression::Literal(_)
|
||||
| crate::Expression::ZeroValue(_)
|
||||
| crate::Expression::Constant(_) => true,
|
||||
crate::Expression::Compose { ref components, .. } => {
|
||||
components.iter().all(|h| inner.contains(h.index()))
|
||||
}
|
||||
crate::Expression::Splat { value, .. } => inner.contains(value.index()),
|
||||
_ => false,
|
||||
};
|
||||
if insert {
|
||||
inner.insert(handle.index());
|
||||
}
|
||||
}
|
||||
Self { inner }
|
||||
}
|
||||
|
||||
fn is_const(&self, value: Handle<crate::Expression>) -> bool {
|
||||
self.inner.contains(value.index())
|
||||
}
|
||||
}
|
||||
|
||||
/// General information needed to emit SPIR-V for Naga statements.
|
||||
struct BlockContext<'w> {
|
||||
/// The writer handling the module to which this code belongs.
|
||||
@@ -554,7 +590,7 @@ struct BlockContext<'w> {
|
||||
temp_list: Vec<Word>,
|
||||
|
||||
/// Tracks the constness of `Expression`s residing in `self.ir_function.expressions`
|
||||
expression_constness: crate::proc::ExpressionConstnessTracker,
|
||||
expression_constness: ExpressionConstnessTracker,
|
||||
}
|
||||
|
||||
impl BlockContext<'_> {
|
||||
|
||||
@@ -615,7 +615,7 @@ impl Writer {
|
||||
// Steal the Writer's temp list for a bit.
|
||||
temp_list: std::mem::take(&mut self.temp_list),
|
||||
writer: self,
|
||||
expression_constness: crate::proc::ExpressionConstnessTracker::from_arena(
|
||||
expression_constness: super::ExpressionConstnessTracker::from_arena(
|
||||
&ir_function.expressions,
|
||||
),
|
||||
};
|
||||
@@ -2029,15 +2029,21 @@ impl Writer {
|
||||
debug_info: &Option<DebugInfo>,
|
||||
words: &mut Vec<Word>,
|
||||
) -> Result<(), Error> {
|
||||
let ir_module = if let Some(pipeline_options) = pipeline_options {
|
||||
let (ir_module, info) = if let Some(pipeline_options) = pipeline_options {
|
||||
crate::back::pipeline_constants::process_overrides(
|
||||
ir_module,
|
||||
info,
|
||||
&pipeline_options.constants,
|
||||
)?
|
||||
)
|
||||
.map_err(Box::new)?
|
||||
} else {
|
||||
std::borrow::Cow::Borrowed(ir_module)
|
||||
(
|
||||
std::borrow::Cow::Borrowed(ir_module),
|
||||
std::borrow::Cow::Borrowed(info),
|
||||
)
|
||||
};
|
||||
let ir_module = ir_module.as_ref();
|
||||
let info = info.as_ref();
|
||||
|
||||
self.reset();
|
||||
|
||||
|
||||
@@ -77,12 +77,19 @@ pub struct Context<'a> {
|
||||
pub body: Block,
|
||||
pub module: &'a mut crate::Module,
|
||||
pub is_const: bool,
|
||||
/// Tracks the constness of `Expression`s residing in `self.expressions`
|
||||
pub expression_constness: crate::proc::ExpressionConstnessTracker,
|
||||
/// Tracks the expression kind of `Expression`s residing in `self.expressions`
|
||||
pub local_expression_kind_tracker: crate::proc::ExpressionConstnessTracker,
|
||||
/// Tracks the expression kind of `Expression`s residing in `self.module.const_expressions`
|
||||
pub global_expression_kind_tracker: &'a mut crate::proc::ExpressionConstnessTracker,
|
||||
}
|
||||
|
||||
impl<'a> Context<'a> {
|
||||
pub fn new(frontend: &Frontend, module: &'a mut crate::Module, is_const: bool) -> Result<Self> {
|
||||
pub fn new(
|
||||
frontend: &Frontend,
|
||||
module: &'a mut crate::Module,
|
||||
is_const: bool,
|
||||
global_expression_kind_tracker: &'a mut crate::proc::ExpressionConstnessTracker,
|
||||
) -> Result<Self> {
|
||||
let mut this = Context {
|
||||
expressions: Arena::new(),
|
||||
locals: Arena::new(),
|
||||
@@ -101,7 +108,8 @@ impl<'a> Context<'a> {
|
||||
body: Block::new(),
|
||||
module,
|
||||
is_const: false,
|
||||
expression_constness: crate::proc::ExpressionConstnessTracker::new(),
|
||||
local_expression_kind_tracker: crate::proc::ExpressionConstnessTracker::new(),
|
||||
global_expression_kind_tracker,
|
||||
};
|
||||
|
||||
this.emit_start();
|
||||
@@ -249,12 +257,15 @@ impl<'a> Context<'a> {
|
||||
|
||||
pub fn add_expression(&mut self, expr: Expression, meta: Span) -> Result<Handle<Expression>> {
|
||||
let mut eval = if self.is_const {
|
||||
crate::proc::ConstantEvaluator::for_glsl_module(self.module)
|
||||
crate::proc::ConstantEvaluator::for_glsl_module(
|
||||
self.module,
|
||||
self.global_expression_kind_tracker,
|
||||
)
|
||||
} else {
|
||||
crate::proc::ConstantEvaluator::for_glsl_function(
|
||||
self.module,
|
||||
&mut self.expressions,
|
||||
&mut self.expression_constness,
|
||||
&mut self.local_expression_kind_tracker,
|
||||
&mut self.emitter,
|
||||
&mut self.body,
|
||||
)
|
||||
|
||||
@@ -1236,6 +1236,8 @@ impl Frontend {
|
||||
let pointer = ctx
|
||||
.expressions
|
||||
.append(Expression::GlobalVariable(arg.handle), Default::default());
|
||||
ctx.local_expression_kind_tracker
|
||||
.insert(pointer, crate::proc::ExpressionKind::Runtime);
|
||||
|
||||
let ty = ctx.module.global_variables[arg.handle].ty;
|
||||
|
||||
@@ -1256,6 +1258,8 @@ impl Frontend {
|
||||
let value = ctx
|
||||
.expressions
|
||||
.append(Expression::FunctionArgument(idx), Default::default());
|
||||
ctx.local_expression_kind_tracker
|
||||
.insert(value, crate::proc::ExpressionKind::Runtime);
|
||||
ctx.body
|
||||
.push(Statement::Store { pointer, value }, Default::default());
|
||||
},
|
||||
@@ -1285,6 +1289,8 @@ impl Frontend {
|
||||
let pointer = ctx
|
||||
.expressions
|
||||
.append(Expression::GlobalVariable(arg.handle), Default::default());
|
||||
ctx.local_expression_kind_tracker
|
||||
.insert(pointer, crate::proc::ExpressionKind::Runtime);
|
||||
|
||||
let ty = ctx.module.global_variables[arg.handle].ty;
|
||||
|
||||
@@ -1307,6 +1313,8 @@ impl Frontend {
|
||||
let load = ctx
|
||||
.expressions
|
||||
.append(Expression::Load { pointer }, Default::default());
|
||||
ctx.local_expression_kind_tracker
|
||||
.insert(load, crate::proc::ExpressionKind::Runtime);
|
||||
ctx.body.push(
|
||||
Statement::Emit(ctx.expressions.range_from(len)),
|
||||
Default::default(),
|
||||
@@ -1329,6 +1337,8 @@ impl Frontend {
|
||||
let res = ctx
|
||||
.expressions
|
||||
.append(Expression::Compose { ty, components }, Default::default());
|
||||
ctx.local_expression_kind_tracker
|
||||
.insert(res, crate::proc::ExpressionKind::Runtime);
|
||||
ctx.body.push(
|
||||
Statement::Emit(ctx.expressions.range_from(len)),
|
||||
Default::default(),
|
||||
|
||||
@@ -164,9 +164,15 @@ impl<'source> ParsingContext<'source> {
|
||||
|
||||
pub fn parse(&mut self, frontend: &mut Frontend) -> Result<Module> {
|
||||
let mut module = Module::default();
|
||||
let mut global_expression_kind_tracker = crate::proc::ExpressionConstnessTracker::new();
|
||||
|
||||
// Body and expression arena for global initialization
|
||||
let mut ctx = Context::new(frontend, &mut module, false)?;
|
||||
let mut ctx = Context::new(
|
||||
frontend,
|
||||
&mut module,
|
||||
false,
|
||||
&mut global_expression_kind_tracker,
|
||||
)?;
|
||||
|
||||
while self.peek(frontend).is_some() {
|
||||
self.parse_external_declaration(frontend, &mut ctx)?;
|
||||
@@ -196,7 +202,11 @@ impl<'source> ParsingContext<'source> {
|
||||
frontend: &mut Frontend,
|
||||
ctx: &mut Context,
|
||||
) -> Result<(u32, Span)> {
|
||||
let (const_expr, meta) = self.parse_constant_expression(frontend, ctx.module)?;
|
||||
let (const_expr, meta) = self.parse_constant_expression(
|
||||
frontend,
|
||||
ctx.module,
|
||||
ctx.global_expression_kind_tracker,
|
||||
)?;
|
||||
|
||||
let res = ctx.module.to_ctx().eval_expr_to_u32(const_expr);
|
||||
|
||||
@@ -219,8 +229,9 @@ impl<'source> ParsingContext<'source> {
|
||||
&mut self,
|
||||
frontend: &mut Frontend,
|
||||
module: &mut Module,
|
||||
global_expression_kind_tracker: &mut crate::proc::ExpressionConstnessTracker,
|
||||
) -> Result<(Handle<Expression>, Span)> {
|
||||
let mut ctx = Context::new(frontend, module, true)?;
|
||||
let mut ctx = Context::new(frontend, module, true, global_expression_kind_tracker)?;
|
||||
|
||||
let mut stmt_ctx = ctx.stmt_ctx();
|
||||
let expr = self.parse_conditional(frontend, &mut ctx, &mut stmt_ctx, None)?;
|
||||
|
||||
@@ -251,7 +251,7 @@ impl<'source> ParsingContext<'source> {
|
||||
init.and_then(|expr| ctx.ctx.lift_up_const_expression(expr).ok());
|
||||
late_initializer = None;
|
||||
} else if let Some(init) = init {
|
||||
if ctx.is_inside_loop || !ctx.ctx.expression_constness.is_const(init) {
|
||||
if ctx.is_inside_loop || !ctx.ctx.local_expression_kind_tracker.is_const(init) {
|
||||
decl_initializer = None;
|
||||
late_initializer = Some(init);
|
||||
} else {
|
||||
@@ -326,7 +326,12 @@ impl<'source> ParsingContext<'source> {
|
||||
|
||||
let result = ty.map(|ty| FunctionResult { ty, binding: None });
|
||||
|
||||
let mut context = Context::new(frontend, ctx.module, false)?;
|
||||
let mut context = Context::new(
|
||||
frontend,
|
||||
ctx.module,
|
||||
false,
|
||||
ctx.global_expression_kind_tracker,
|
||||
)?;
|
||||
|
||||
self.parse_function_args(frontend, &mut context)?;
|
||||
|
||||
|
||||
@@ -192,8 +192,11 @@ impl<'source> ParsingContext<'source> {
|
||||
TokenValue::Case => {
|
||||
self.bump(frontend)?;
|
||||
|
||||
let (const_expr, meta) =
|
||||
self.parse_constant_expression(frontend, ctx.module)?;
|
||||
let (const_expr, meta) = self.parse_constant_expression(
|
||||
frontend,
|
||||
ctx.module,
|
||||
ctx.global_expression_kind_tracker,
|
||||
)?;
|
||||
|
||||
match ctx.module.const_expressions[const_expr] {
|
||||
Expression::Literal(Literal::I32(value)) => match uint {
|
||||
|
||||
@@ -330,7 +330,7 @@ impl Context<'_> {
|
||||
expr: Handle<Expression>,
|
||||
) -> Result<Handle<Expression>> {
|
||||
let meta = self.expressions.get_span(expr);
|
||||
Ok(match self.expressions[expr] {
|
||||
let h = match self.expressions[expr] {
|
||||
ref expr @ (Expression::Literal(_)
|
||||
| Expression::Constant(_)
|
||||
| Expression::ZeroValue(_)) => self.module.const_expressions.append(expr.clone(), meta),
|
||||
@@ -355,6 +355,9 @@ impl Context<'_> {
|
||||
meta,
|
||||
})
|
||||
}
|
||||
})
|
||||
};
|
||||
self.global_expression_kind_tracker
|
||||
.insert(h, crate::proc::ExpressionKind::Const);
|
||||
Ok(h)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -86,6 +86,8 @@ pub struct GlobalContext<'source, 'temp, 'out> {
|
||||
module: &'out mut crate::Module,
|
||||
|
||||
const_typifier: &'temp mut Typifier,
|
||||
|
||||
global_expression_kind_tracker: &'temp mut crate::proc::ExpressionConstnessTracker,
|
||||
}
|
||||
|
||||
impl<'source> GlobalContext<'source, '_, '_> {
|
||||
@@ -97,6 +99,19 @@ impl<'source> GlobalContext<'source, '_, '_> {
|
||||
module: self.module,
|
||||
const_typifier: self.const_typifier,
|
||||
expr_type: ExpressionContextType::Constant,
|
||||
global_expression_kind_tracker: self.global_expression_kind_tracker,
|
||||
}
|
||||
}
|
||||
|
||||
fn as_override(&mut self) -> ExpressionContext<'source, '_, '_> {
|
||||
ExpressionContext {
|
||||
ast_expressions: self.ast_expressions,
|
||||
globals: self.globals,
|
||||
types: self.types,
|
||||
module: self.module,
|
||||
const_typifier: self.const_typifier,
|
||||
expr_type: ExpressionContextType::Override,
|
||||
global_expression_kind_tracker: self.global_expression_kind_tracker,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -165,6 +180,7 @@ pub struct StatementContext<'source, 'temp, 'out> {
|
||||
/// we should consider them to be const. See the use of `force_non_const` in
|
||||
/// the code for lowering `let` bindings.
|
||||
expression_constness: &'temp mut crate::proc::ExpressionConstnessTracker,
|
||||
global_expression_kind_tracker: &'temp mut crate::proc::ExpressionConstnessTracker,
|
||||
}
|
||||
|
||||
impl<'a, 'temp> StatementContext<'a, 'temp, '_> {
|
||||
@@ -181,6 +197,7 @@ impl<'a, 'temp> StatementContext<'a, 'temp, '_> {
|
||||
types: self.types,
|
||||
ast_expressions: self.ast_expressions,
|
||||
const_typifier: self.const_typifier,
|
||||
global_expression_kind_tracker: self.global_expression_kind_tracker,
|
||||
module: self.module,
|
||||
expr_type: ExpressionContextType::Runtime(RuntimeExpressionContext {
|
||||
local_table: self.local_table,
|
||||
@@ -200,6 +217,7 @@ impl<'a, 'temp> StatementContext<'a, 'temp, '_> {
|
||||
types: self.types,
|
||||
module: self.module,
|
||||
const_typifier: self.const_typifier,
|
||||
global_expression_kind_tracker: self.global_expression_kind_tracker,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -253,6 +271,14 @@ pub enum ExpressionContextType<'temp, 'out> {
|
||||
/// available in the [`ExpressionContext`], so this variant
|
||||
/// carries no further information.
|
||||
Constant,
|
||||
|
||||
/// We are lowering to an override expression, to be included in the module's
|
||||
/// constant expression arena.
|
||||
///
|
||||
/// Everything override expressions are allowed to refer to is
|
||||
/// available in the [`ExpressionContext`], so this variant
|
||||
/// carries no further information.
|
||||
Override,
|
||||
}
|
||||
|
||||
/// State for lowering an [`ast::Expression`] to Naga IR.
|
||||
@@ -311,6 +337,7 @@ pub struct ExpressionContext<'source, 'temp, 'out> {
|
||||
///
|
||||
/// [`module::const_expressions`]: crate::Module::const_expressions
|
||||
const_typifier: &'temp mut Typifier,
|
||||
global_expression_kind_tracker: &'temp mut crate::proc::ExpressionConstnessTracker,
|
||||
|
||||
/// Whether we are lowering a constant expression or a general
|
||||
/// runtime expression, and the data needed in each case.
|
||||
@@ -326,6 +353,7 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
|
||||
const_typifier: self.const_typifier,
|
||||
module: self.module,
|
||||
expr_type: ExpressionContextType::Constant,
|
||||
global_expression_kind_tracker: self.global_expression_kind_tracker,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -336,6 +364,7 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
|
||||
types: self.types,
|
||||
module: self.module,
|
||||
const_typifier: self.const_typifier,
|
||||
global_expression_kind_tracker: self.global_expression_kind_tracker,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -348,7 +377,16 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
|
||||
rctx.emitter,
|
||||
rctx.block,
|
||||
),
|
||||
ExpressionContextType::Constant => ConstantEvaluator::for_wgsl_module(self.module),
|
||||
ExpressionContextType::Constant => ConstantEvaluator::for_wgsl_module(
|
||||
self.module,
|
||||
self.global_expression_kind_tracker,
|
||||
false,
|
||||
),
|
||||
ExpressionContextType::Override => ConstantEvaluator::for_wgsl_module(
|
||||
self.module,
|
||||
self.global_expression_kind_tracker,
|
||||
true,
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -375,20 +413,25 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
|
||||
.ok()
|
||||
}
|
||||
ExpressionContextType::Constant => self.module.to_ctx().eval_expr_to_u32(handle).ok(),
|
||||
ExpressionContextType::Override => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn get_expression_span(&self, handle: Handle<crate::Expression>) -> Span {
|
||||
match self.expr_type {
|
||||
ExpressionContextType::Runtime(ref ctx) => ctx.function.expressions.get_span(handle),
|
||||
ExpressionContextType::Constant => self.module.const_expressions.get_span(handle),
|
||||
ExpressionContextType::Constant | ExpressionContextType::Override => {
|
||||
self.module.const_expressions.get_span(handle)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn typifier(&self) -> &Typifier {
|
||||
match self.expr_type {
|
||||
ExpressionContextType::Runtime(ref ctx) => ctx.typifier,
|
||||
ExpressionContextType::Constant => self.const_typifier,
|
||||
ExpressionContextType::Constant | ExpressionContextType::Override => {
|
||||
self.const_typifier
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -398,7 +441,9 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
|
||||
) -> Result<&mut RuntimeExpressionContext<'temp, 'out>, Error<'source>> {
|
||||
match self.expr_type {
|
||||
ExpressionContextType::Runtime(ref mut ctx) => Ok(ctx),
|
||||
ExpressionContextType::Constant => Err(Error::UnexpectedOperationInConstContext(span)),
|
||||
ExpressionContextType::Constant | ExpressionContextType::Override => {
|
||||
Err(Error::UnexpectedOperationInConstContext(span))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -435,7 +480,7 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
|
||||
}
|
||||
// This means a `gather` operation appeared in a constant expression.
|
||||
// This error refers to the `gather` itself, not its "component" argument.
|
||||
ExpressionContextType::Constant => {
|
||||
ExpressionContextType::Constant | ExpressionContextType::Override => {
|
||||
Err(Error::UnexpectedOperationInConstContext(gather_span))
|
||||
}
|
||||
}
|
||||
@@ -461,7 +506,9 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
|
||||
// to also borrow self.module.types mutably below.
|
||||
let typifier = match self.expr_type {
|
||||
ExpressionContextType::Runtime(ref ctx) => ctx.typifier,
|
||||
ExpressionContextType::Constant => &*self.const_typifier,
|
||||
ExpressionContextType::Constant | ExpressionContextType::Override => {
|
||||
&*self.const_typifier
|
||||
}
|
||||
};
|
||||
Ok(typifier.register_type(handle, &mut self.module.types))
|
||||
}
|
||||
@@ -504,7 +551,7 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
|
||||
typifier = &mut *ctx.typifier;
|
||||
expressions = &ctx.function.expressions;
|
||||
}
|
||||
ExpressionContextType::Constant => {
|
||||
ExpressionContextType::Constant | ExpressionContextType::Override => {
|
||||
resolve_ctx = ResolveContext::with_locals(self.module, &empty_arena, &[]);
|
||||
typifier = self.const_typifier;
|
||||
expressions = &self.module.const_expressions;
|
||||
@@ -600,14 +647,14 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
|
||||
rctx.block
|
||||
.extend(rctx.emitter.finish(&rctx.function.expressions));
|
||||
}
|
||||
ExpressionContextType::Constant => {}
|
||||
ExpressionContextType::Constant | ExpressionContextType::Override => {}
|
||||
}
|
||||
let result = self.append_expression(expression, span);
|
||||
match self.expr_type {
|
||||
ExpressionContextType::Runtime(ref mut rctx) => {
|
||||
rctx.emitter.start(&rctx.function.expressions);
|
||||
}
|
||||
ExpressionContextType::Constant => {}
|
||||
ExpressionContextType::Constant | ExpressionContextType::Override => {}
|
||||
}
|
||||
result
|
||||
}
|
||||
@@ -852,6 +899,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
|
||||
types: &tu.types,
|
||||
module: &mut module,
|
||||
const_typifier: &mut Typifier::new(),
|
||||
global_expression_kind_tracker: &mut crate::proc::ExpressionConstnessTracker::new(),
|
||||
};
|
||||
|
||||
for decl_handle in self.index.visit_ordered() {
|
||||
@@ -959,7 +1007,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
|
||||
ast::GlobalDeclKind::Override(ref o) => {
|
||||
let init = o
|
||||
.init
|
||||
.map(|init| self.expression(init, &mut ctx.as_const()))
|
||||
.map(|init| self.expression(init, &mut ctx.as_override()))
|
||||
.transpose()?;
|
||||
let inferred_type = init
|
||||
.map(|init| ctx.as_const().register_type(init))
|
||||
@@ -1049,6 +1097,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
|
||||
let mut local_table = FastHashMap::default();
|
||||
let mut expressions = Arena::new();
|
||||
let mut named_expressions = FastIndexMap::default();
|
||||
let mut local_expression_kind_tracker = crate::proc::ExpressionConstnessTracker::new();
|
||||
|
||||
let arguments = f
|
||||
.arguments
|
||||
@@ -1060,6 +1109,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
|
||||
.append(crate::Expression::FunctionArgument(i as u32), arg.name.span);
|
||||
local_table.insert(arg.handle, Typed::Plain(expr));
|
||||
named_expressions.insert(expr, (arg.name.name.to_string(), arg.name.span));
|
||||
local_expression_kind_tracker.insert(expr, crate::proc::ExpressionKind::Runtime);
|
||||
|
||||
Ok(crate::FunctionArgument {
|
||||
name: Some(arg.name.name.to_string()),
|
||||
@@ -1102,7 +1152,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
|
||||
named_expressions: &mut named_expressions,
|
||||
types: ctx.types,
|
||||
module: ctx.module,
|
||||
expression_constness: &mut crate::proc::ExpressionConstnessTracker::new(),
|
||||
expression_constness: &mut local_expression_kind_tracker,
|
||||
global_expression_kind_tracker: ctx.global_expression_kind_tracker,
|
||||
};
|
||||
let mut body = self.block(&f.body, false, &mut stmt_ctx)?;
|
||||
ensure_block_returns(&mut body);
|
||||
@@ -1518,6 +1569,10 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
|
||||
.function
|
||||
.expressions
|
||||
.append(crate::Expression::Binary { op, left, right }, stmt.span);
|
||||
rctx.expression_constness
|
||||
.insert(left, crate::proc::ExpressionKind::Runtime);
|
||||
rctx.expression_constness
|
||||
.insert(value, crate::proc::ExpressionKind::Runtime);
|
||||
|
||||
block.extend(emitter.finish(&ctx.function.expressions));
|
||||
crate::Statement::Store {
|
||||
@@ -1611,7 +1666,12 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
|
||||
LoweredGlobalDecl::Const(handle) => {
|
||||
Typed::Plain(crate::Expression::Constant(handle))
|
||||
}
|
||||
_ => {
|
||||
LoweredGlobalDecl::Override(handle) => {
|
||||
Typed::Plain(crate::Expression::Override(handle))
|
||||
}
|
||||
LoweredGlobalDecl::Function(_)
|
||||
| LoweredGlobalDecl::Type(_)
|
||||
| LoweredGlobalDecl::EntryPoint => {
|
||||
return Err(Error::Unexpected(span, ExpectedToken::Variable));
|
||||
}
|
||||
};
|
||||
@@ -1886,9 +1946,13 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
|
||||
rctx.block
|
||||
.extend(rctx.emitter.finish(&rctx.function.expressions));
|
||||
let result = has_result.then(|| {
|
||||
rctx.function
|
||||
let result = rctx
|
||||
.function
|
||||
.expressions
|
||||
.append(crate::Expression::CallResult(function), span)
|
||||
.append(crate::Expression::CallResult(function), span);
|
||||
rctx.expression_constness
|
||||
.insert(result, crate::proc::ExpressionKind::Runtime);
|
||||
result
|
||||
});
|
||||
rctx.emitter.start(&rctx.function.expressions);
|
||||
rctx.block.push(
|
||||
|
||||
@@ -253,9 +253,9 @@ gen_component_wise_extractor! {
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum Behavior {
|
||||
Wgsl,
|
||||
Glsl,
|
||||
enum Behavior<'a> {
|
||||
Wgsl(WgslRestrictions<'a>),
|
||||
Glsl(GlslRestrictions<'a>),
|
||||
}
|
||||
|
||||
/// A context for evaluating constant expressions.
|
||||
@@ -278,7 +278,7 @@ enum Behavior {
|
||||
#[derive(Debug)]
|
||||
pub struct ConstantEvaluator<'a> {
|
||||
/// Which language's evaluation rules we should follow.
|
||||
behavior: Behavior,
|
||||
behavior: Behavior<'a>,
|
||||
|
||||
/// The module's type arena.
|
||||
///
|
||||
@@ -297,65 +297,145 @@ pub struct ConstantEvaluator<'a> {
|
||||
/// The arena to which we are contributing expressions.
|
||||
expressions: &'a mut Arena<Expression>,
|
||||
|
||||
/// When `self.expressions` refers to a function's local expression
|
||||
/// arena, this needs to be populated
|
||||
function_local_data: Option<FunctionLocalData<'a>>,
|
||||
/// Tracks the constness of expressions residing in [`Self::expressions`]
|
||||
expression_kind_tracker: &'a mut ExpressionConstnessTracker,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum WgslRestrictions<'a> {
|
||||
/// - const-expressions will be evaluated and inserted in the arena
|
||||
Const,
|
||||
/// - const-expressions will be evaluated and inserted in the arena
|
||||
/// - override-expressions will be inserted in the arena
|
||||
Override,
|
||||
/// - const-expressions will be evaluated and inserted in the arena
|
||||
/// - override-expressions will be inserted in the arena
|
||||
/// - runtime-expressions will be inserted in the arena
|
||||
Runtime(FunctionLocalData<'a>),
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum GlslRestrictions<'a> {
|
||||
/// - const-expressions will be evaluated and inserted in the arena
|
||||
Const,
|
||||
/// - const-expressions will be evaluated and inserted in the arena
|
||||
/// - override-expressions will be inserted in the arena
|
||||
/// - runtime-expressions will be inserted in the arena
|
||||
Runtime(FunctionLocalData<'a>),
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct FunctionLocalData<'a> {
|
||||
/// Global constant expressions
|
||||
const_expressions: &'a Arena<Expression>,
|
||||
/// Tracks the constness of expressions residing in `ConstantEvaluator.expressions`
|
||||
expression_constness: &'a mut ExpressionConstnessTracker,
|
||||
emitter: &'a mut super::Emitter,
|
||||
block: &'a mut crate::Block,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
|
||||
pub enum ExpressionKind {
|
||||
Const,
|
||||
Override,
|
||||
Runtime,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ExpressionConstnessTracker {
|
||||
inner: bit_set::BitSet,
|
||||
inner: Vec<ExpressionKind>,
|
||||
}
|
||||
|
||||
impl ExpressionConstnessTracker {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
inner: bit_set::BitSet::new(),
|
||||
}
|
||||
pub const fn new() -> Self {
|
||||
Self { inner: Vec::new() }
|
||||
}
|
||||
|
||||
/// Forces the the expression to not be const
|
||||
pub fn force_non_const(&mut self, value: Handle<Expression>) {
|
||||
self.inner.remove(value.index());
|
||||
self.inner[value.index()] = ExpressionKind::Runtime;
|
||||
}
|
||||
|
||||
fn insert(&mut self, value: Handle<Expression>) {
|
||||
self.inner.insert(value.index());
|
||||
pub fn insert(&mut self, value: Handle<Expression>, expr_type: ExpressionKind) {
|
||||
assert_eq!(self.inner.len(), value.index());
|
||||
self.inner.push(expr_type);
|
||||
}
|
||||
pub fn is_const(&self, h: Handle<Expression>) -> bool {
|
||||
matches!(self.type_of(h), ExpressionKind::Const)
|
||||
}
|
||||
|
||||
pub fn is_const(&self, value: Handle<Expression>) -> bool {
|
||||
self.inner.contains(value.index())
|
||||
pub fn is_const_or_override(&self, h: Handle<Expression>) -> bool {
|
||||
matches!(
|
||||
self.type_of(h),
|
||||
ExpressionKind::Const | ExpressionKind::Override
|
||||
)
|
||||
}
|
||||
|
||||
fn type_of(&self, value: Handle<Expression>) -> ExpressionKind {
|
||||
self.inner[value.index()]
|
||||
}
|
||||
|
||||
pub fn from_arena(arena: &Arena<Expression>) -> Self {
|
||||
let mut tracker = Self::new();
|
||||
for (handle, expr) in arena.iter() {
|
||||
let insert = match *expr {
|
||||
crate::Expression::Literal(_)
|
||||
| crate::Expression::ZeroValue(_)
|
||||
| crate::Expression::Constant(_) => true,
|
||||
crate::Expression::Compose { ref components, .. } => {
|
||||
components.iter().all(|h| tracker.is_const(*h))
|
||||
}
|
||||
crate::Expression::Splat { value, .. } => tracker.is_const(value),
|
||||
_ => false,
|
||||
};
|
||||
if insert {
|
||||
tracker.insert(handle);
|
||||
}
|
||||
let mut tracker = Self {
|
||||
inner: Vec::with_capacity(arena.len()),
|
||||
};
|
||||
for (_, expr) in arena.iter() {
|
||||
tracker.inner.push(tracker.type_of_with_expr(expr));
|
||||
}
|
||||
tracker
|
||||
}
|
||||
|
||||
fn type_of_with_expr(&self, expr: &Expression) -> ExpressionKind {
|
||||
match *expr {
|
||||
Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => {
|
||||
ExpressionKind::Const
|
||||
}
|
||||
Expression::Override(_) => ExpressionKind::Override,
|
||||
Expression::Compose { ref components, .. } => {
|
||||
let mut expr_type = ExpressionKind::Const;
|
||||
for component in components {
|
||||
expr_type = expr_type.max(self.type_of(*component))
|
||||
}
|
||||
expr_type
|
||||
}
|
||||
Expression::Splat { value, .. } => self.type_of(value),
|
||||
Expression::AccessIndex { base, .. } => self.type_of(base),
|
||||
Expression::Access { base, index } => self.type_of(base).max(self.type_of(index)),
|
||||
Expression::Swizzle { vector, .. } => self.type_of(vector),
|
||||
Expression::Unary { expr, .. } => self.type_of(expr),
|
||||
Expression::Binary { left, right, .. } => self.type_of(left).max(self.type_of(right)),
|
||||
Expression::Math {
|
||||
arg,
|
||||
arg1,
|
||||
arg2,
|
||||
arg3,
|
||||
..
|
||||
} => self
|
||||
.type_of(arg)
|
||||
.max(
|
||||
arg1.map(|arg| self.type_of(arg))
|
||||
.unwrap_or(ExpressionKind::Const),
|
||||
)
|
||||
.max(
|
||||
arg2.map(|arg| self.type_of(arg))
|
||||
.unwrap_or(ExpressionKind::Const),
|
||||
)
|
||||
.max(
|
||||
arg3.map(|arg| self.type_of(arg))
|
||||
.unwrap_or(ExpressionKind::Const),
|
||||
),
|
||||
Expression::As { expr, .. } => self.type_of(expr),
|
||||
Expression::Select {
|
||||
condition,
|
||||
accept,
|
||||
reject,
|
||||
} => self
|
||||
.type_of(condition)
|
||||
.max(self.type_of(accept))
|
||||
.max(self.type_of(reject)),
|
||||
Expression::Relational { argument, .. } => self.type_of(argument),
|
||||
Expression::ArrayLength(expr) => self.type_of(expr),
|
||||
_ => ExpressionKind::Runtime,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, thiserror::Error)]
|
||||
@@ -436,6 +516,12 @@ pub enum ConstantEvaluatorError {
|
||||
ShiftedMoreThan32Bits,
|
||||
#[error(transparent)]
|
||||
Literal(#[from] crate::valid::LiteralError),
|
||||
#[error("Can't use pipeline-overridable constants in const-expressions")]
|
||||
Override,
|
||||
#[error("Unexpected runtime-expression")]
|
||||
RuntimeExpr,
|
||||
#[error("Unexpected override-expression")]
|
||||
OverrideExpr,
|
||||
}
|
||||
|
||||
impl<'a> ConstantEvaluator<'a> {
|
||||
@@ -443,26 +529,49 @@ impl<'a> ConstantEvaluator<'a> {
|
||||
/// constant expression arena.
|
||||
///
|
||||
/// Report errors according to WGSL's rules for constant evaluation.
|
||||
pub fn for_wgsl_module(module: &'a mut crate::Module) -> Self {
|
||||
Self::for_module(Behavior::Wgsl, module)
|
||||
pub fn for_wgsl_module(
|
||||
module: &'a mut crate::Module,
|
||||
global_expression_kind_tracker: &'a mut ExpressionConstnessTracker,
|
||||
in_override_ctx: bool,
|
||||
) -> Self {
|
||||
Self::for_module(
|
||||
Behavior::Wgsl(if in_override_ctx {
|
||||
WgslRestrictions::Override
|
||||
} else {
|
||||
WgslRestrictions::Const
|
||||
}),
|
||||
module,
|
||||
global_expression_kind_tracker,
|
||||
)
|
||||
}
|
||||
|
||||
/// Return a [`ConstantEvaluator`] that will add expressions to `module`'s
|
||||
/// constant expression arena.
|
||||
///
|
||||
/// Report errors according to GLSL's rules for constant evaluation.
|
||||
pub fn for_glsl_module(module: &'a mut crate::Module) -> Self {
|
||||
Self::for_module(Behavior::Glsl, module)
|
||||
pub fn for_glsl_module(
|
||||
module: &'a mut crate::Module,
|
||||
global_expression_kind_tracker: &'a mut ExpressionConstnessTracker,
|
||||
) -> Self {
|
||||
Self::for_module(
|
||||
Behavior::Glsl(GlslRestrictions::Const),
|
||||
module,
|
||||
global_expression_kind_tracker,
|
||||
)
|
||||
}
|
||||
|
||||
fn for_module(behavior: Behavior, module: &'a mut crate::Module) -> Self {
|
||||
fn for_module(
|
||||
behavior: Behavior<'a>,
|
||||
module: &'a mut crate::Module,
|
||||
global_expression_kind_tracker: &'a mut ExpressionConstnessTracker,
|
||||
) -> Self {
|
||||
Self {
|
||||
behavior,
|
||||
types: &mut module.types,
|
||||
constants: &module.constants,
|
||||
overrides: &module.overrides,
|
||||
expressions: &mut module.const_expressions,
|
||||
function_local_data: None,
|
||||
expression_kind_tracker: global_expression_kind_tracker,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -473,18 +582,22 @@ impl<'a> ConstantEvaluator<'a> {
|
||||
pub fn for_wgsl_function(
|
||||
module: &'a mut crate::Module,
|
||||
expressions: &'a mut Arena<Expression>,
|
||||
expression_constness: &'a mut ExpressionConstnessTracker,
|
||||
local_expression_kind_tracker: &'a mut ExpressionConstnessTracker,
|
||||
emitter: &'a mut super::Emitter,
|
||||
block: &'a mut crate::Block,
|
||||
) -> Self {
|
||||
Self::for_function(
|
||||
Behavior::Wgsl,
|
||||
module,
|
||||
Self {
|
||||
behavior: Behavior::Wgsl(WgslRestrictions::Runtime(FunctionLocalData {
|
||||
const_expressions: &module.const_expressions,
|
||||
emitter,
|
||||
block,
|
||||
})),
|
||||
types: &mut module.types,
|
||||
constants: &module.constants,
|
||||
overrides: &module.overrides,
|
||||
expressions,
|
||||
expression_constness,
|
||||
emitter,
|
||||
block,
|
||||
)
|
||||
expression_kind_tracker: local_expression_kind_tracker,
|
||||
}
|
||||
}
|
||||
|
||||
/// Return a [`ConstantEvaluator`] that will add expressions to `function`'s
|
||||
@@ -494,40 +607,21 @@ impl<'a> ConstantEvaluator<'a> {
|
||||
pub fn for_glsl_function(
|
||||
module: &'a mut crate::Module,
|
||||
expressions: &'a mut Arena<Expression>,
|
||||
expression_constness: &'a mut ExpressionConstnessTracker,
|
||||
emitter: &'a mut super::Emitter,
|
||||
block: &'a mut crate::Block,
|
||||
) -> Self {
|
||||
Self::for_function(
|
||||
Behavior::Glsl,
|
||||
module,
|
||||
expressions,
|
||||
expression_constness,
|
||||
emitter,
|
||||
block,
|
||||
)
|
||||
}
|
||||
|
||||
fn for_function(
|
||||
behavior: Behavior,
|
||||
module: &'a mut crate::Module,
|
||||
expressions: &'a mut Arena<Expression>,
|
||||
expression_constness: &'a mut ExpressionConstnessTracker,
|
||||
local_expression_kind_tracker: &'a mut ExpressionConstnessTracker,
|
||||
emitter: &'a mut super::Emitter,
|
||||
block: &'a mut crate::Block,
|
||||
) -> Self {
|
||||
Self {
|
||||
behavior,
|
||||
behavior: Behavior::Glsl(GlslRestrictions::Runtime(FunctionLocalData {
|
||||
const_expressions: &module.const_expressions,
|
||||
emitter,
|
||||
block,
|
||||
})),
|
||||
types: &mut module.types,
|
||||
constants: &module.constants,
|
||||
overrides: &module.overrides,
|
||||
expressions,
|
||||
function_local_data: Some(FunctionLocalData {
|
||||
const_expressions: &module.const_expressions,
|
||||
expression_constness,
|
||||
emitter,
|
||||
block,
|
||||
}),
|
||||
expression_kind_tracker: local_expression_kind_tracker,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -536,19 +630,17 @@ impl<'a> ConstantEvaluator<'a> {
|
||||
types: self.types,
|
||||
constants: self.constants,
|
||||
overrides: self.overrides,
|
||||
const_expressions: match self.function_local_data {
|
||||
Some(ref data) => data.const_expressions,
|
||||
const_expressions: match self.function_local_data() {
|
||||
Some(data) => data.const_expressions,
|
||||
None => self.expressions,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn check(&self, expr: Handle<Expression>) -> Result<(), ConstantEvaluatorError> {
|
||||
if let Some(ref function_local_data) = self.function_local_data {
|
||||
if !function_local_data.expression_constness.is_const(expr) {
|
||||
log::debug!("check: SubexpressionsAreNotConstant");
|
||||
return Err(ConstantEvaluatorError::SubexpressionsAreNotConstant);
|
||||
}
|
||||
if !self.expression_kind_tracker.is_const(expr) {
|
||||
log::debug!("check: SubexpressionsAreNotConstant");
|
||||
return Err(ConstantEvaluatorError::SubexpressionsAreNotConstant);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
@@ -561,7 +653,7 @@ impl<'a> ConstantEvaluator<'a> {
|
||||
Expression::Constant(c) => {
|
||||
// Are we working in a function's expression arena, or the
|
||||
// module's constant expression arena?
|
||||
if let Some(ref function_local_data) = self.function_local_data {
|
||||
if let Some(function_local_data) = self.function_local_data() {
|
||||
// Deep-copy the constant's value into our arena.
|
||||
self.copy_from(
|
||||
self.constants[c].init,
|
||||
@@ -607,14 +699,56 @@ impl<'a> ConstantEvaluator<'a> {
|
||||
expr: Expression,
|
||||
span: Span,
|
||||
) -> Result<Handle<Expression>, ConstantEvaluatorError> {
|
||||
let res = self.try_eval_and_append_impl(&expr, span);
|
||||
if self.function_local_data.is_some() {
|
||||
match res {
|
||||
Ok(h) => Ok(h),
|
||||
Err(_) => Ok(self.append_expr(expr, span, false)),
|
||||
match (
|
||||
&self.behavior,
|
||||
self.expression_kind_tracker.type_of_with_expr(&expr),
|
||||
) {
|
||||
// avoid errors on unimplemented functionality if possible
|
||||
(
|
||||
&Behavior::Wgsl(WgslRestrictions::Runtime(_))
|
||||
| &Behavior::Glsl(GlslRestrictions::Runtime(_)),
|
||||
ExpressionKind::Const,
|
||||
) => match self.try_eval_and_append_impl(&expr, span) {
|
||||
Err(
|
||||
ConstantEvaluatorError::NotImplemented(_)
|
||||
| ConstantEvaluatorError::InvalidBinaryOpArgs,
|
||||
) => Ok(self.append_expr(expr, span, ExpressionKind::Runtime)),
|
||||
res => res,
|
||||
},
|
||||
(_, ExpressionKind::Const) => self.try_eval_and_append_impl(&expr, span),
|
||||
(&Behavior::Wgsl(WgslRestrictions::Const), ExpressionKind::Override) => {
|
||||
Err(ConstantEvaluatorError::OverrideExpr)
|
||||
}
|
||||
} else {
|
||||
res
|
||||
(
|
||||
&Behavior::Wgsl(WgslRestrictions::Override | WgslRestrictions::Runtime(_)),
|
||||
ExpressionKind::Override,
|
||||
) => Ok(self.append_expr(expr, span, ExpressionKind::Override)),
|
||||
(&Behavior::Glsl(_), ExpressionKind::Override) => unreachable!(),
|
||||
(
|
||||
&Behavior::Wgsl(WgslRestrictions::Runtime(_))
|
||||
| &Behavior::Glsl(GlslRestrictions::Runtime(_)),
|
||||
ExpressionKind::Runtime,
|
||||
) => Ok(self.append_expr(expr, span, ExpressionKind::Runtime)),
|
||||
(_, ExpressionKind::Runtime) => Err(ConstantEvaluatorError::RuntimeExpr),
|
||||
}
|
||||
}
|
||||
|
||||
/// Is the [`Self::expressions`] arena the global module expression arena?
|
||||
const fn is_global_arena(&self) -> bool {
|
||||
matches!(
|
||||
self.behavior,
|
||||
Behavior::Wgsl(WgslRestrictions::Const | WgslRestrictions::Override)
|
||||
| Behavior::Glsl(GlslRestrictions::Const)
|
||||
)
|
||||
}
|
||||
|
||||
const fn function_local_data(&self) -> Option<&FunctionLocalData<'a>> {
|
||||
match self.behavior {
|
||||
Behavior::Wgsl(WgslRestrictions::Runtime(ref function_local_data))
|
||||
| Behavior::Glsl(GlslRestrictions::Runtime(ref function_local_data)) => {
|
||||
Some(function_local_data)
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -625,14 +759,12 @@ impl<'a> ConstantEvaluator<'a> {
|
||||
) -> Result<Handle<Expression>, ConstantEvaluatorError> {
|
||||
log::trace!("try_eval_and_append: {:?}", expr);
|
||||
match *expr {
|
||||
Expression::Constant(c) if self.function_local_data.is_none() => {
|
||||
Expression::Constant(c) if self.is_global_arena() => {
|
||||
// "See through" the constant and use its initializer.
|
||||
// This is mainly done to avoid having constants pointing to other constants.
|
||||
Ok(self.constants[c].init)
|
||||
}
|
||||
Expression::Override(_) => Err(ConstantEvaluatorError::NotImplemented(
|
||||
"overrides are WIP".into(),
|
||||
)),
|
||||
Expression::Override(_) => Err(ConstantEvaluatorError::Override),
|
||||
Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => {
|
||||
self.register_evaluated_expr(expr.clone(), span)
|
||||
}
|
||||
@@ -713,8 +845,8 @@ impl<'a> ConstantEvaluator<'a> {
|
||||
format!("{fun:?} built-in function"),
|
||||
)),
|
||||
Expression::ArrayLength(expr) => match self.behavior {
|
||||
Behavior::Wgsl => Err(ConstantEvaluatorError::ArrayLength),
|
||||
Behavior::Glsl => {
|
||||
Behavior::Wgsl(_) => Err(ConstantEvaluatorError::ArrayLength),
|
||||
Behavior::Glsl(_) => {
|
||||
let expr = self.check_and_get(expr)?;
|
||||
self.array_length(expr, span)
|
||||
}
|
||||
@@ -1881,34 +2013,35 @@ impl<'a> ConstantEvaluator<'a> {
|
||||
crate::valid::check_literal_value(literal)?;
|
||||
}
|
||||
|
||||
Ok(self.append_expr(expr, span, true))
|
||||
Ok(self.append_expr(expr, span, ExpressionKind::Const))
|
||||
}
|
||||
|
||||
fn append_expr(&mut self, expr: Expression, span: Span, is_const: bool) -> Handle<Expression> {
|
||||
if let Some(FunctionLocalData {
|
||||
ref mut emitter,
|
||||
ref mut block,
|
||||
ref mut expression_constness,
|
||||
..
|
||||
}) = self.function_local_data
|
||||
{
|
||||
let is_running = emitter.is_running();
|
||||
let needs_pre_emit = expr.needs_pre_emit();
|
||||
let h = if is_running && needs_pre_emit {
|
||||
block.extend(emitter.finish(self.expressions));
|
||||
let h = self.expressions.append(expr, span);
|
||||
emitter.start(self.expressions);
|
||||
h
|
||||
} else {
|
||||
self.expressions.append(expr, span)
|
||||
};
|
||||
if is_const {
|
||||
expression_constness.insert(h);
|
||||
fn append_expr(
|
||||
&mut self,
|
||||
expr: Expression,
|
||||
span: Span,
|
||||
expr_type: ExpressionKind,
|
||||
) -> Handle<Expression> {
|
||||
let h = match self.behavior {
|
||||
Behavior::Wgsl(WgslRestrictions::Runtime(ref mut function_local_data))
|
||||
| Behavior::Glsl(GlslRestrictions::Runtime(ref mut function_local_data)) => {
|
||||
let is_running = function_local_data.emitter.is_running();
|
||||
let needs_pre_emit = expr.needs_pre_emit();
|
||||
if is_running && needs_pre_emit {
|
||||
function_local_data
|
||||
.block
|
||||
.extend(function_local_data.emitter.finish(self.expressions));
|
||||
let h = self.expressions.append(expr, span);
|
||||
function_local_data.emitter.start(self.expressions);
|
||||
h
|
||||
} else {
|
||||
self.expressions.append(expr, span)
|
||||
}
|
||||
}
|
||||
h
|
||||
} else {
|
||||
self.expressions.append(expr, span)
|
||||
}
|
||||
_ => self.expressions.append(expr, span),
|
||||
};
|
||||
self.expression_kind_tracker.insert(h, expr_type);
|
||||
h
|
||||
}
|
||||
|
||||
fn resolve_type(
|
||||
@@ -2062,7 +2195,7 @@ mod tests {
|
||||
UniqueArena, VectorSize,
|
||||
};
|
||||
|
||||
use super::{Behavior, ConstantEvaluator};
|
||||
use super::{Behavior, ConstantEvaluator, ExpressionConstnessTracker, WgslRestrictions};
|
||||
|
||||
#[test]
|
||||
fn unary_op() {
|
||||
@@ -2143,13 +2276,15 @@ mod tests {
|
||||
expr: expr1,
|
||||
};
|
||||
|
||||
let expression_kind_tracker =
|
||||
&mut ExpressionConstnessTracker::from_arena(&const_expressions);
|
||||
let mut solver = ConstantEvaluator {
|
||||
behavior: Behavior::Wgsl,
|
||||
behavior: Behavior::Wgsl(WgslRestrictions::Const),
|
||||
types: &mut types,
|
||||
constants: &constants,
|
||||
overrides: &overrides,
|
||||
expressions: &mut const_expressions,
|
||||
function_local_data: None,
|
||||
expression_kind_tracker,
|
||||
};
|
||||
|
||||
let res1 = solver
|
||||
@@ -2228,13 +2363,15 @@ mod tests {
|
||||
convert: Some(crate::BOOL_WIDTH),
|
||||
};
|
||||
|
||||
let expression_kind_tracker =
|
||||
&mut ExpressionConstnessTracker::from_arena(&const_expressions);
|
||||
let mut solver = ConstantEvaluator {
|
||||
behavior: Behavior::Wgsl,
|
||||
behavior: Behavior::Wgsl(WgslRestrictions::Const),
|
||||
types: &mut types,
|
||||
constants: &constants,
|
||||
overrides: &overrides,
|
||||
expressions: &mut const_expressions,
|
||||
function_local_data: None,
|
||||
expression_kind_tracker,
|
||||
};
|
||||
|
||||
let res = solver
|
||||
@@ -2345,13 +2482,15 @@ mod tests {
|
||||
|
||||
let base = const_expressions.append(Expression::Constant(h), Default::default());
|
||||
|
||||
let expression_kind_tracker =
|
||||
&mut ExpressionConstnessTracker::from_arena(&const_expressions);
|
||||
let mut solver = ConstantEvaluator {
|
||||
behavior: Behavior::Wgsl,
|
||||
behavior: Behavior::Wgsl(WgslRestrictions::Const),
|
||||
types: &mut types,
|
||||
constants: &constants,
|
||||
overrides: &overrides,
|
||||
expressions: &mut const_expressions,
|
||||
function_local_data: None,
|
||||
expression_kind_tracker,
|
||||
};
|
||||
|
||||
let root1 = Expression::AccessIndex { base, index: 1 };
|
||||
@@ -2437,13 +2576,15 @@ mod tests {
|
||||
|
||||
let h_expr = const_expressions.append(Expression::Constant(h), Default::default());
|
||||
|
||||
let expression_kind_tracker =
|
||||
&mut ExpressionConstnessTracker::from_arena(&const_expressions);
|
||||
let mut solver = ConstantEvaluator {
|
||||
behavior: Behavior::Wgsl,
|
||||
behavior: Behavior::Wgsl(WgslRestrictions::Const),
|
||||
types: &mut types,
|
||||
constants: &constants,
|
||||
overrides: &overrides,
|
||||
expressions: &mut const_expressions,
|
||||
function_local_data: None,
|
||||
expression_kind_tracker,
|
||||
};
|
||||
|
||||
let solved_compose = solver
|
||||
@@ -2518,13 +2659,15 @@ mod tests {
|
||||
|
||||
let h_expr = const_expressions.append(Expression::Constant(h), Default::default());
|
||||
|
||||
let expression_kind_tracker =
|
||||
&mut ExpressionConstnessTracker::from_arena(&const_expressions);
|
||||
let mut solver = ConstantEvaluator {
|
||||
behavior: Behavior::Wgsl,
|
||||
behavior: Behavior::Wgsl(WgslRestrictions::Const),
|
||||
types: &mut types,
|
||||
constants: &constants,
|
||||
overrides: &overrides,
|
||||
expressions: &mut const_expressions,
|
||||
function_local_data: None,
|
||||
expression_kind_tracker,
|
||||
};
|
||||
|
||||
let solved_compose = solver
|
||||
|
||||
@@ -11,7 +11,7 @@ mod terminator;
|
||||
mod typifier;
|
||||
|
||||
pub use constant_evaluator::{
|
||||
ConstantEvaluator, ConstantEvaluatorError, ExpressionConstnessTracker,
|
||||
ConstantEvaluator, ConstantEvaluatorError, ExpressionConstnessTracker, ExpressionKind,
|
||||
};
|
||||
pub use emitter::Emitter;
|
||||
pub use index::{BoundsCheckPolicies, BoundsCheckPolicy, IndexableLength, IndexableLengthError};
|
||||
|
||||
@@ -226,7 +226,7 @@ struct Sampling {
|
||||
sampler: GlobalOrArgument,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Clone)]
|
||||
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
|
||||
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
|
||||
pub struct FunctionInfo {
|
||||
|
||||
@@ -90,6 +90,8 @@ pub enum ExpressionError {
|
||||
sampler: bool,
|
||||
has_ref: bool,
|
||||
},
|
||||
#[error("Sample offset must be a const-expression")]
|
||||
InvalidSampleOffsetExprType,
|
||||
#[error("Sample offset constant {1:?} doesn't match the image dimension {0:?}")]
|
||||
InvalidSampleOffset(crate::ImageDimension, Handle<crate::Expression>),
|
||||
#[error("Depth reference {0:?} is not a scalar float")]
|
||||
@@ -129,9 +131,10 @@ pub enum ExpressionError {
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, thiserror::Error)]
|
||||
#[cfg_attr(test, derive(PartialEq))]
|
||||
pub enum ConstExpressionError {
|
||||
#[error("The expression is not a constant expression")]
|
||||
NonConst,
|
||||
#[error("The expression is not a constant or override expression")]
|
||||
NonConstOrOverride,
|
||||
#[error(transparent)]
|
||||
Compose(#[from] super::ComposeError),
|
||||
#[error("Splatting {0:?} can't be done")]
|
||||
@@ -184,9 +187,14 @@ impl super::Validator {
|
||||
handle: Handle<crate::Expression>,
|
||||
gctx: crate::proc::GlobalCtx,
|
||||
mod_info: &ModuleInfo,
|
||||
global_expr_kind: &crate::proc::ExpressionConstnessTracker,
|
||||
) -> Result<(), ConstExpressionError> {
|
||||
use crate::Expression as E;
|
||||
|
||||
if !global_expr_kind.is_const_or_override(handle) {
|
||||
return Err(super::ConstExpressionError::NonConstOrOverride);
|
||||
}
|
||||
|
||||
match gctx.const_expressions[handle] {
|
||||
E::Literal(literal) => {
|
||||
self.validate_literal(literal)?;
|
||||
@@ -203,12 +211,14 @@ impl super::Validator {
|
||||
crate::TypeInner::Scalar { .. } => {}
|
||||
_ => return Err(super::ConstExpressionError::InvalidSplatType(value)),
|
||||
},
|
||||
_ => return Err(super::ConstExpressionError::NonConst),
|
||||
// the constant evaluator will report errors about override-expressions
|
||||
_ => {}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(super) fn validate_expression(
|
||||
&self,
|
||||
root: Handle<crate::Expression>,
|
||||
@@ -217,6 +227,7 @@ impl super::Validator {
|
||||
module: &crate::Module,
|
||||
info: &FunctionInfo,
|
||||
mod_info: &ModuleInfo,
|
||||
global_expr_kind: &crate::proc::ExpressionConstnessTracker,
|
||||
) -> Result<ShaderStages, ExpressionError> {
|
||||
use crate::{Expression as E, Scalar as Sc, ScalarKind as Sk, TypeInner as Ti};
|
||||
|
||||
@@ -462,6 +473,10 @@ impl super::Validator {
|
||||
|
||||
// check constant offset
|
||||
if let Some(const_expr) = offset {
|
||||
if !global_expr_kind.is_const(const_expr) {
|
||||
return Err(ExpressionError::InvalidSampleOffsetExprType);
|
||||
}
|
||||
|
||||
match *mod_info[const_expr].inner_with(&module.types) {
|
||||
Ti::Scalar(Sc { kind: Sk::Sint, .. }) if num_components == 1 => {}
|
||||
Ti::Vector {
|
||||
|
||||
@@ -927,7 +927,7 @@ impl super::Validator {
|
||||
var: &crate::LocalVariable,
|
||||
gctx: crate::proc::GlobalCtx,
|
||||
fun_info: &FunctionInfo,
|
||||
expression_constness: &crate::proc::ExpressionConstnessTracker,
|
||||
local_expr_kind: &crate::proc::ExpressionConstnessTracker,
|
||||
) -> Result<(), LocalVariableError> {
|
||||
log::debug!("var {:?}", var);
|
||||
let type_info = self
|
||||
@@ -945,7 +945,7 @@ impl super::Validator {
|
||||
return Err(LocalVariableError::InitializerType);
|
||||
}
|
||||
|
||||
if !expression_constness.is_const(init) {
|
||||
if !local_expr_kind.is_const(init) {
|
||||
return Err(LocalVariableError::NonConstInitializer);
|
||||
}
|
||||
}
|
||||
@@ -959,14 +959,14 @@ impl super::Validator {
|
||||
module: &crate::Module,
|
||||
mod_info: &ModuleInfo,
|
||||
entry_point: bool,
|
||||
global_expr_kind: &crate::proc::ExpressionConstnessTracker,
|
||||
) -> Result<FunctionInfo, WithSpan<FunctionError>> {
|
||||
let mut info = mod_info.process_function(fun, module, self.flags, self.capabilities)?;
|
||||
|
||||
let expression_constness =
|
||||
crate::proc::ExpressionConstnessTracker::from_arena(&fun.expressions);
|
||||
let local_expr_kind = crate::proc::ExpressionConstnessTracker::from_arena(&fun.expressions);
|
||||
|
||||
for (var_handle, var) in fun.local_variables.iter() {
|
||||
self.validate_local_var(var, module.to_ctx(), &info, &expression_constness)
|
||||
self.validate_local_var(var, module.to_ctx(), &info, &local_expr_kind)
|
||||
.map_err(|source| {
|
||||
FunctionError::LocalVariable {
|
||||
handle: var_handle,
|
||||
@@ -1032,7 +1032,15 @@ impl super::Validator {
|
||||
self.valid_expression_set.insert(handle.index());
|
||||
}
|
||||
if self.flags.contains(super::ValidationFlags::EXPRESSIONS) {
|
||||
match self.validate_expression(handle, expr, fun, module, &info, mod_info) {
|
||||
match self.validate_expression(
|
||||
handle,
|
||||
expr,
|
||||
fun,
|
||||
module,
|
||||
&info,
|
||||
mod_info,
|
||||
global_expr_kind,
|
||||
) {
|
||||
Ok(stages) => info.available_stages &= stages,
|
||||
Err(source) => {
|
||||
return Err(FunctionError::Expression { handle, source }
|
||||
|
||||
@@ -592,6 +592,7 @@ impl From<BadRangeError> for ValidationError {
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, thiserror::Error)]
|
||||
#[cfg_attr(test, derive(PartialEq))]
|
||||
pub enum InvalidHandleError {
|
||||
#[error(transparent)]
|
||||
BadHandle(#[from] BadHandle),
|
||||
@@ -602,6 +603,7 @@ pub enum InvalidHandleError {
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, thiserror::Error)]
|
||||
#[cfg_attr(test, derive(PartialEq))]
|
||||
#[error(
|
||||
"{subject:?} of kind {subject_kind:?} depends on {depends_on:?} of kind {depends_on_kind}, \
|
||||
which has not been processed yet"
|
||||
|
||||
@@ -10,6 +10,7 @@ use bit_set::BitSet;
|
||||
const MAX_WORKGROUP_SIZE: u32 = 0x4000;
|
||||
|
||||
#[derive(Clone, Debug, thiserror::Error)]
|
||||
#[cfg_attr(test, derive(PartialEq))]
|
||||
pub enum GlobalVariableError {
|
||||
#[error("Usage isn't compatible with address space {0:?}")]
|
||||
InvalidUsage(crate::AddressSpace),
|
||||
@@ -30,6 +31,8 @@ pub enum GlobalVariableError {
|
||||
Handle<crate::Type>,
|
||||
#[source] Disalignment,
|
||||
),
|
||||
#[error("Initializer must be a const-expression")]
|
||||
InitializerExprType,
|
||||
#[error("Initializer doesn't match the variable type")]
|
||||
InitializerType,
|
||||
#[error("Initializer can't be used with address space {0:?}")]
|
||||
@@ -39,6 +42,7 @@ pub enum GlobalVariableError {
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, thiserror::Error)]
|
||||
#[cfg_attr(test, derive(PartialEq))]
|
||||
pub enum VaryingError {
|
||||
#[error("The type {0:?} does not match the varying")]
|
||||
InvalidType(Handle<crate::Type>),
|
||||
@@ -76,6 +80,7 @@ pub enum VaryingError {
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, thiserror::Error)]
|
||||
#[cfg_attr(test, derive(PartialEq))]
|
||||
pub enum EntryPointError {
|
||||
#[error("Multiple conflicting entry points")]
|
||||
Conflict,
|
||||
@@ -395,6 +400,7 @@ impl super::Validator {
|
||||
var: &crate::GlobalVariable,
|
||||
gctx: crate::proc::GlobalCtx,
|
||||
mod_info: &ModuleInfo,
|
||||
global_expr_kind: &crate::proc::ExpressionConstnessTracker,
|
||||
) -> Result<(), GlobalVariableError> {
|
||||
use super::TypeFlags;
|
||||
|
||||
@@ -523,6 +529,10 @@ impl super::Validator {
|
||||
}
|
||||
}
|
||||
|
||||
if !global_expr_kind.is_const(init) {
|
||||
return Err(GlobalVariableError::InitializerExprType);
|
||||
}
|
||||
|
||||
let decl_ty = &gctx.types[var.ty].inner;
|
||||
let init_ty = mod_info[init].inner_with(gctx.types);
|
||||
if !decl_ty.equivalent(init_ty, gctx.types) {
|
||||
@@ -538,6 +548,7 @@ impl super::Validator {
|
||||
ep: &crate::EntryPoint,
|
||||
module: &crate::Module,
|
||||
mod_info: &ModuleInfo,
|
||||
global_expr_kind: &crate::proc::ExpressionConstnessTracker,
|
||||
) -> Result<FunctionInfo, WithSpan<EntryPointError>> {
|
||||
if ep.early_depth_test.is_some() {
|
||||
let required = Capabilities::EARLY_DEPTH_TEST;
|
||||
@@ -566,7 +577,7 @@ impl super::Validator {
|
||||
}
|
||||
|
||||
let mut info = self
|
||||
.validate_function(&ep.function, module, mod_info, true)
|
||||
.validate_function(&ep.function, module, mod_info, true, global_expr_kind)
|
||||
.map_err(WithSpan::into_other)?;
|
||||
|
||||
{
|
||||
|
||||
@@ -12,7 +12,7 @@ mod r#type;
|
||||
|
||||
use crate::{
|
||||
arena::Handle,
|
||||
proc::{LayoutError, Layouter, TypeResolution},
|
||||
proc::{ExpressionConstnessTracker, LayoutError, Layouter, TypeResolution},
|
||||
FastHashSet,
|
||||
};
|
||||
use bit_set::BitSet;
|
||||
@@ -131,7 +131,7 @@ bitflags::bitflags! {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[derive(Debug, Clone)]
|
||||
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
|
||||
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
|
||||
pub struct ModuleInfo {
|
||||
@@ -178,7 +178,10 @@ pub struct Validator {
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, thiserror::Error)]
|
||||
#[cfg_attr(test, derive(PartialEq))]
|
||||
pub enum ConstantError {
|
||||
#[error("Initializer must be a const-expression")]
|
||||
InitializerExprType,
|
||||
#[error("The type doesn't match the constant")]
|
||||
InvalidType,
|
||||
#[error("The type is not constructible")]
|
||||
@@ -186,11 +189,14 @@ pub enum ConstantError {
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, thiserror::Error)]
|
||||
#[cfg_attr(test, derive(PartialEq))]
|
||||
pub enum OverrideError {
|
||||
#[error("Override name and ID are missing")]
|
||||
MissingNameAndID,
|
||||
#[error("Override ID must be unique")]
|
||||
DuplicateID,
|
||||
#[error("Initializer must be a const-expression or override-expression")]
|
||||
InitializerExprType,
|
||||
#[error("The type doesn't match the override")]
|
||||
InvalidType,
|
||||
#[error("The type is not constructible")]
|
||||
@@ -200,6 +206,7 @@ pub enum OverrideError {
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, thiserror::Error)]
|
||||
#[cfg_attr(test, derive(PartialEq))]
|
||||
pub enum ValidationError {
|
||||
#[error(transparent)]
|
||||
InvalidHandle(#[from] InvalidHandleError),
|
||||
@@ -335,6 +342,7 @@ impl Validator {
|
||||
handle: Handle<crate::Constant>,
|
||||
gctx: crate::proc::GlobalCtx,
|
||||
mod_info: &ModuleInfo,
|
||||
global_expr_kind: &ExpressionConstnessTracker,
|
||||
) -> Result<(), ConstantError> {
|
||||
let con = &gctx.constants[handle];
|
||||
|
||||
@@ -343,6 +351,10 @@ impl Validator {
|
||||
return Err(ConstantError::NonConstructibleType);
|
||||
}
|
||||
|
||||
if !global_expr_kind.is_const(con.init) {
|
||||
return Err(ConstantError::InitializerExprType);
|
||||
}
|
||||
|
||||
let decl_ty = &gctx.types[con.ty].inner;
|
||||
let init_ty = mod_info[con.init].inner_with(gctx.types);
|
||||
if !decl_ty.equivalent(init_ty, gctx.types) {
|
||||
@@ -455,17 +467,24 @@ impl Validator {
|
||||
}
|
||||
}
|
||||
|
||||
let global_expr_kind = ExpressionConstnessTracker::from_arena(&module.const_expressions);
|
||||
|
||||
if self.flags.contains(ValidationFlags::CONSTANTS) {
|
||||
for (handle, _) in module.const_expressions.iter() {
|
||||
self.validate_const_expression(handle, module.to_ctx(), &mod_info)
|
||||
.map_err(|source| {
|
||||
ValidationError::ConstExpression { handle, source }
|
||||
.with_span_handle(handle, &module.const_expressions)
|
||||
})?
|
||||
self.validate_const_expression(
|
||||
handle,
|
||||
module.to_ctx(),
|
||||
&mod_info,
|
||||
&global_expr_kind,
|
||||
)
|
||||
.map_err(|source| {
|
||||
ValidationError::ConstExpression { handle, source }
|
||||
.with_span_handle(handle, &module.const_expressions)
|
||||
})?
|
||||
}
|
||||
|
||||
for (handle, constant) in module.constants.iter() {
|
||||
self.validate_constant(handle, module.to_ctx(), &mod_info)
|
||||
self.validate_constant(handle, module.to_ctx(), &mod_info, &global_expr_kind)
|
||||
.map_err(|source| {
|
||||
ValidationError::Constant {
|
||||
handle,
|
||||
@@ -490,7 +509,7 @@ impl Validator {
|
||||
}
|
||||
|
||||
for (var_handle, var) in module.global_variables.iter() {
|
||||
self.validate_global_var(var, module.to_ctx(), &mod_info)
|
||||
self.validate_global_var(var, module.to_ctx(), &mod_info, &global_expr_kind)
|
||||
.map_err(|source| {
|
||||
ValidationError::GlobalVariable {
|
||||
handle: var_handle,
|
||||
@@ -502,7 +521,7 @@ impl Validator {
|
||||
}
|
||||
|
||||
for (handle, fun) in module.functions.iter() {
|
||||
match self.validate_function(fun, module, &mod_info, false) {
|
||||
match self.validate_function(fun, module, &mod_info, false, &global_expr_kind) {
|
||||
Ok(info) => mod_info.functions.push(info),
|
||||
Err(error) => {
|
||||
return Err(error.and_then(|source| {
|
||||
@@ -528,7 +547,7 @@ impl Validator {
|
||||
.with_span()); // TODO: keep some EP span information?
|
||||
}
|
||||
|
||||
match self.validate_entry_point(ep, module, &mod_info) {
|
||||
match self.validate_entry_point(ep, module, &mod_info, &global_expr_kind) {
|
||||
Ok(info) => mod_info.entry_points.push(info),
|
||||
Err(error) => {
|
||||
return Err(error.and_then(|source| {
|
||||
|
||||
@@ -63,6 +63,7 @@ bitflags::bitflags! {
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, thiserror::Error)]
|
||||
#[cfg_attr(test, derive(PartialEq))]
|
||||
pub enum Disalignment {
|
||||
#[error("The array stride {stride} is not a multiple of the required alignment {alignment}")]
|
||||
ArrayStride { stride: u32, alignment: Alignment },
|
||||
@@ -87,6 +88,7 @@ pub enum Disalignment {
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, thiserror::Error)]
|
||||
#[cfg_attr(test, derive(PartialEq))]
|
||||
pub enum TypeError {
|
||||
#[error("Capability {0:?} is required")]
|
||||
MissingCapability(Capabilities),
|
||||
|
||||
@@ -6,7 +6,7 @@
|
||||
override depth: f32; // Specified at the API level using
|
||||
// the name "depth".
|
||||
// Must be overridden.
|
||||
// override height = 2 * depth; // The default value
|
||||
override height = 2 * depth; // The default value
|
||||
// (if not set at the API level),
|
||||
// depends on another
|
||||
// overridable constant.
|
||||
|
||||
@@ -33,6 +33,12 @@
|
||||
kind: Float,
|
||||
width: 4,
|
||||
))),
|
||||
Handle(2),
|
||||
Value(Scalar((
|
||||
kind: Float,
|
||||
width: 4,
|
||||
))),
|
||||
Handle(2),
|
||||
Value(Scalar((
|
||||
kind: Float,
|
||||
width: 4,
|
||||
|
||||
@@ -3,6 +3,7 @@ static const float specular_param = 2.3;
|
||||
static const float gain = 1.1;
|
||||
static const float width = 0.0;
|
||||
static const float depth = 2.3;
|
||||
static const float height = 4.6;
|
||||
static const float inferred_f32_ = 2.718;
|
||||
|
||||
[numthreads(1, 1, 1)]
|
||||
|
||||
@@ -52,11 +52,17 @@
|
||||
ty: 2,
|
||||
init: None,
|
||||
),
|
||||
(
|
||||
name: Some("height"),
|
||||
id: None,
|
||||
ty: 2,
|
||||
init: Some(6),
|
||||
),
|
||||
(
|
||||
name: Some("inferred_f32"),
|
||||
id: None,
|
||||
ty: 2,
|
||||
init: Some(4),
|
||||
init: Some(7),
|
||||
),
|
||||
],
|
||||
global_variables: [],
|
||||
@@ -64,6 +70,13 @@
|
||||
Literal(Bool(true)),
|
||||
Literal(F32(2.3)),
|
||||
Literal(F32(0.0)),
|
||||
Override(5),
|
||||
Literal(F32(2.0)),
|
||||
Binary(
|
||||
op: Multiply,
|
||||
left: 5,
|
||||
right: 4,
|
||||
),
|
||||
Literal(F32(2.718)),
|
||||
],
|
||||
functions: [],
|
||||
|
||||
@@ -52,11 +52,17 @@
|
||||
ty: 2,
|
||||
init: None,
|
||||
),
|
||||
(
|
||||
name: Some("height"),
|
||||
id: None,
|
||||
ty: 2,
|
||||
init: Some(6),
|
||||
),
|
||||
(
|
||||
name: Some("inferred_f32"),
|
||||
id: None,
|
||||
ty: 2,
|
||||
init: Some(4),
|
||||
init: Some(7),
|
||||
),
|
||||
],
|
||||
global_variables: [],
|
||||
@@ -64,6 +70,13 @@
|
||||
Literal(Bool(true)),
|
||||
Literal(F32(2.3)),
|
||||
Literal(F32(0.0)),
|
||||
Override(5),
|
||||
Literal(F32(2.0)),
|
||||
Binary(
|
||||
op: Multiply,
|
||||
left: 5,
|
||||
right: 4,
|
||||
),
|
||||
Literal(F32(2.718)),
|
||||
],
|
||||
functions: [],
|
||||
|
||||
@@ -9,6 +9,7 @@ constant float specular_param = 2.3;
|
||||
constant float gain = 1.1;
|
||||
constant float width = 0.0;
|
||||
constant float depth = 2.3;
|
||||
constant float height = 4.6;
|
||||
constant float inferred_f32_ = 2.718;
|
||||
|
||||
kernel void main_(
|
||||
|
||||
@@ -1,25 +1,27 @@
|
||||
; SPIR-V
|
||||
; Version: 1.0
|
||||
; Generator: rspirv
|
||||
; Bound: 15
|
||||
; Bound: 17
|
||||
OpCapability Shader
|
||||
%1 = OpExtInstImport "GLSL.std.450"
|
||||
OpMemoryModel Logical GLSL450
|
||||
OpEntryPoint GLCompute %12 "main"
|
||||
OpExecutionMode %12 LocalSize 1 1 1
|
||||
OpEntryPoint GLCompute %14 "main"
|
||||
OpExecutionMode %14 LocalSize 1 1 1
|
||||
%2 = OpTypeVoid
|
||||
%3 = OpTypeBool
|
||||
%4 = OpTypeFloat 32
|
||||
%5 = OpConstantTrue %3
|
||||
%6 = OpConstant %4 2.3
|
||||
%7 = OpConstant %4 0.0
|
||||
%8 = OpConstant %4 2.718
|
||||
%9 = OpConstantFalse %3
|
||||
%10 = OpConstant %4 1.1
|
||||
%13 = OpTypeFunction %2
|
||||
%12 = OpFunction %2 None %13
|
||||
%11 = OpLabel
|
||||
OpBranch %14
|
||||
%14 = OpLabel
|
||||
%8 = OpConstantFalse %3
|
||||
%9 = OpConstant %4 1.1
|
||||
%10 = OpConstant %4 2.0
|
||||
%11 = OpConstant %4 4.6
|
||||
%12 = OpConstant %4 2.718
|
||||
%15 = OpTypeFunction %2
|
||||
%14 = OpFunction %2 None %15
|
||||
%13 = OpLabel
|
||||
OpBranch %16
|
||||
%16 = OpLabel
|
||||
OpReturn
|
||||
OpFunctionEnd
|
||||
@@ -14,8 +14,8 @@ fn main_1() {
|
||||
let _e4 = a_uv_1;
|
||||
v_uv = _e4;
|
||||
let _e6 = a_pos_1;
|
||||
let _e8 = (c_scale * _e6);
|
||||
gl_Position = vec4<f32>(_e8.x, _e8.y, 0f, 1f);
|
||||
let _e7 = (c_scale * _e6);
|
||||
gl_Position = vec4<f32>(_e7.x, _e7.y, 0f, 1f);
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user