implement override-expression evaluation for initializers of override declarations

This commit is contained in:
teoxoy
2024-02-14 15:17:07 +01:00
committed by Teodor Tanasoaia
parent ff332afdef
commit d6ebd88f42
32 changed files with 909 additions and 242 deletions

View File

@@ -122,6 +122,7 @@ impl<T> Handle<T> {
serde(transparent)
)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[cfg_attr(test, derive(PartialEq))]
pub struct Range<T> {
inner: ops::Range<u32>,
#[cfg_attr(any(feature = "serialize", feature = "deserialize"), serde(skip))]
@@ -140,6 +141,7 @@ impl<T> Range<T> {
// NOTE: Keep this diagnostic in sync with that of [`BadHandle`].
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
#[error("Handle range {range:?} of {kind} is either not present, or inaccessible yet")]
pub struct BadRangeError {
// This error is used for many `Handle` types, but there's no point in making this generic, so

View File

@@ -256,7 +256,7 @@ pub enum Error {
#[error("{0}")]
Custom(String),
#[error(transparent)]
PipelineConstant(#[from] back::pipeline_constants::PipelineConstantError),
PipelineConstant(#[from] Box<back::pipeline_constants::PipelineConstantError>),
}
#[derive(Default)]

View File

@@ -169,9 +169,14 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
module_info: &valid::ModuleInfo,
pipeline_options: &PipelineOptions,
) -> Result<super::ReflectionInfo, Error> {
let module =
back::pipeline_constants::process_overrides(module, &pipeline_options.constants)?;
let (module, module_info) = back::pipeline_constants::process_overrides(
module,
module_info,
&pipeline_options.constants,
)
.map_err(Box::new)?;
let module = module.as_ref();
let module_info = module_info.as_ref();
self.reset(module);

View File

@@ -144,7 +144,7 @@ pub enum Error {
#[error("ray tracing is not supported prior to MSL 2.3")]
UnsupportedRayTracing,
#[error(transparent)]
PipelineConstant(#[from] crate::back::pipeline_constants::PipelineConstantError),
PipelineConstant(#[from] Box<crate::back::pipeline_constants::PipelineConstantError>),
}
#[derive(Clone, Debug, PartialEq, thiserror::Error)]

View File

@@ -3223,9 +3223,11 @@ impl<W: Write> Writer<W> {
options: &Options,
pipeline_options: &PipelineOptions,
) -> Result<TranslationInfo, Error> {
let module =
back::pipeline_constants::process_overrides(module, &pipeline_options.constants)?;
let (module, info) =
back::pipeline_constants::process_overrides(module, info, &pipeline_options.constants)
.map_err(Box::new)?;
let module = module.as_ref();
let info = info.as_ref();
self.names.clear();
self.namer.reset(

View File

@@ -1,6 +1,10 @@
use super::PipelineConstants;
use crate::{Constant, Expression, Literal, Module, Scalar, Span, TypeInner};
use std::borrow::Cow;
use crate::{
proc::{ConstantEvaluator, ConstantEvaluatorError},
valid::{Capabilities, ModuleInfo, ValidationError, ValidationFlags, Validator},
Constant, Expression, Handle, Literal, Module, Override, Scalar, Span, TypeInner, WithSpan,
};
use std::{borrow::Cow, collections::HashSet};
use thiserror::Error;
#[derive(Error, Debug, Clone)]
@@ -12,48 +16,317 @@ pub enum PipelineConstantError {
SrcNeedsToBeFinite,
#[error("Source f64 value doesn't fit in destination")]
DstRangeTooSmall,
#[error(transparent)]
ConstantEvaluatorError(#[from] ConstantEvaluatorError),
#[error(transparent)]
ValidationError(#[from] WithSpan<ValidationError>),
}
pub(super) fn process_overrides<'a>(
module: &'a Module,
module_info: &'a ModuleInfo,
pipeline_constants: &PipelineConstants,
) -> Result<Cow<'a, Module>, PipelineConstantError> {
) -> Result<(Cow<'a, Module>, Cow<'a, ModuleInfo>), PipelineConstantError> {
if module.overrides.is_empty() {
return Ok(Cow::Borrowed(module));
return Ok((Cow::Borrowed(module), Cow::Borrowed(module_info)));
}
let mut module = module.clone();
let mut override_map = Vec::with_capacity(module.overrides.len());
let mut adjusted_const_expressions = Vec::with_capacity(module.const_expressions.len());
let mut adjusted_constant_initializers = HashSet::with_capacity(module.constants.len());
for (_handle, override_, span) in module.overrides.drain() {
let key = if let Some(id) = override_.id {
Cow::Owned(id.to_string())
} else if let Some(ref name) = override_.name {
Cow::Borrowed(name)
} else {
unreachable!();
let mut global_expression_kind_tracker = crate::proc::ExpressionConstnessTracker::new();
let mut override_iter = module.overrides.drain();
for (old_h, expr, span) in module.const_expressions.drain() {
let mut expr = match expr {
Expression::Override(h) => {
let c_h = if let Some(new_h) = override_map.get(h.index()) {
*new_h
} else {
let mut new_h = None;
for entry in override_iter.by_ref() {
let stop = entry.0 == h;
new_h = Some(process_override(
entry,
pipeline_constants,
&mut module,
&mut override_map,
&adjusted_const_expressions,
&mut adjusted_constant_initializers,
&mut global_expression_kind_tracker,
)?);
if stop {
break;
}
}
new_h.unwrap()
};
Expression::Constant(c_h)
}
Expression::Constant(c_h) => {
adjusted_constant_initializers.insert(c_h);
module.constants[c_h].init = adjusted_const_expressions[c_h.index()];
expr
}
expr => expr,
};
let init = if let Some(value) = pipeline_constants.get::<str>(&key) {
let literal = match module.types[override_.ty].inner {
TypeInner::Scalar(scalar) => map_value_to_literal(*value, scalar)?,
_ => unreachable!(),
};
module
.const_expressions
.append(Expression::Literal(literal), Span::UNDEFINED)
} else if let Some(init) = override_.init {
init
} else {
return Err(PipelineConstantError::MissingValue(key.to_string()));
};
let constant = Constant {
name: override_.name,
ty: override_.ty,
init,
};
module.constants.append(constant, span);
let mut evaluator = ConstantEvaluator::for_wgsl_module(
&mut module,
&mut global_expression_kind_tracker,
false,
);
adjust_expr(&adjusted_const_expressions, &mut expr);
let h = evaluator.try_eval_and_append(expr, span)?;
debug_assert_eq!(old_h.index(), adjusted_const_expressions.len());
adjusted_const_expressions.push(h);
}
Ok(Cow::Owned(module))
for entry in override_iter {
process_override(
entry,
pipeline_constants,
&mut module,
&mut override_map,
&adjusted_const_expressions,
&mut adjusted_constant_initializers,
&mut global_expression_kind_tracker,
)?;
}
for (_, c) in module
.constants
.iter_mut()
.filter(|&(c_h, _)| !adjusted_constant_initializers.contains(&c_h))
{
c.init = adjusted_const_expressions[c.init.index()];
}
for (_, v) in module.global_variables.iter_mut() {
if let Some(ref mut init) = v.init {
*init = adjusted_const_expressions[init.index()];
}
}
let mut validator = Validator::new(ValidationFlags::all(), Capabilities::all());
let module_info = validator.validate(&module)?;
Ok((Cow::Owned(module), Cow::Owned(module_info)))
}
fn process_override(
(old_h, override_, span): (Handle<Override>, Override, Span),
pipeline_constants: &PipelineConstants,
module: &mut Module,
override_map: &mut Vec<Handle<Constant>>,
adjusted_const_expressions: &[Handle<Expression>],
adjusted_constant_initializers: &mut HashSet<Handle<Constant>>,
global_expression_kind_tracker: &mut crate::proc::ExpressionConstnessTracker,
) -> Result<Handle<Constant>, PipelineConstantError> {
let key = if let Some(id) = override_.id {
Cow::Owned(id.to_string())
} else if let Some(ref name) = override_.name {
Cow::Borrowed(name)
} else {
unreachable!();
};
let init = if let Some(value) = pipeline_constants.get::<str>(&key) {
let literal = match module.types[override_.ty].inner {
TypeInner::Scalar(scalar) => map_value_to_literal(*value, scalar)?,
_ => unreachable!(),
};
let expr = module
.const_expressions
.append(Expression::Literal(literal), Span::UNDEFINED);
global_expression_kind_tracker.insert(expr, crate::proc::ExpressionKind::Const);
expr
} else if let Some(init) = override_.init {
adjusted_const_expressions[init.index()]
} else {
return Err(PipelineConstantError::MissingValue(key.to_string()));
};
let constant = Constant {
name: override_.name,
ty: override_.ty,
init,
};
let h = module.constants.append(constant, span);
debug_assert_eq!(old_h.index(), override_map.len());
override_map.push(h);
adjusted_constant_initializers.insert(h);
Ok(h)
}
fn adjust_expr(new_pos: &[Handle<Expression>], expr: &mut Expression) {
let adjust = |expr: &mut Handle<Expression>| {
*expr = new_pos[expr.index()];
};
match *expr {
Expression::Compose {
ref mut components, ..
} => {
for c in components.iter_mut() {
adjust(c);
}
}
Expression::Access {
ref mut base,
ref mut index,
} => {
adjust(base);
adjust(index);
}
Expression::AccessIndex { ref mut base, .. } => {
adjust(base);
}
Expression::Splat { ref mut value, .. } => {
adjust(value);
}
Expression::Swizzle { ref mut vector, .. } => {
adjust(vector);
}
Expression::Load { ref mut pointer } => {
adjust(pointer);
}
Expression::ImageSample {
ref mut image,
ref mut sampler,
ref mut coordinate,
ref mut array_index,
ref mut offset,
ref mut level,
ref mut depth_ref,
..
} => {
adjust(image);
adjust(sampler);
adjust(coordinate);
if let Some(e) = array_index.as_mut() {
adjust(e);
}
if let Some(e) = offset.as_mut() {
adjust(e);
}
match *level {
crate::SampleLevel::Exact(ref mut expr)
| crate::SampleLevel::Bias(ref mut expr) => {
adjust(expr);
}
crate::SampleLevel::Gradient {
ref mut x,
ref mut y,
} => {
adjust(x);
adjust(y);
}
_ => {}
}
if let Some(e) = depth_ref.as_mut() {
adjust(e);
}
}
Expression::ImageLoad {
ref mut image,
ref mut coordinate,
ref mut array_index,
ref mut sample,
ref mut level,
} => {
adjust(image);
adjust(coordinate);
if let Some(e) = array_index.as_mut() {
adjust(e);
}
if let Some(e) = sample.as_mut() {
adjust(e);
}
if let Some(e) = level.as_mut() {
adjust(e);
}
}
Expression::ImageQuery {
ref mut image,
ref mut query,
} => {
adjust(image);
match *query {
crate::ImageQuery::Size { ref mut level } => {
if let Some(e) = level.as_mut() {
adjust(e);
}
}
_ => {}
}
}
Expression::Unary { ref mut expr, .. } => {
adjust(expr);
}
Expression::Binary {
ref mut left,
ref mut right,
..
} => {
adjust(left);
adjust(right);
}
Expression::Select {
ref mut condition,
ref mut accept,
ref mut reject,
} => {
adjust(condition);
adjust(accept);
adjust(reject);
}
Expression::Derivative { ref mut expr, .. } => {
adjust(expr);
}
Expression::Relational {
ref mut argument, ..
} => {
adjust(argument);
}
Expression::Math {
ref mut arg,
ref mut arg1,
ref mut arg2,
ref mut arg3,
..
} => {
adjust(arg);
if let Some(e) = arg1.as_mut() {
adjust(e);
}
if let Some(e) = arg2.as_mut() {
adjust(e);
}
if let Some(e) = arg3.as_mut() {
adjust(e);
}
}
Expression::As { ref mut expr, .. } => {
adjust(expr);
}
Expression::ArrayLength(ref mut expr) => {
adjust(expr);
}
Expression::RayQueryGetIntersection { ref mut query, .. } => {
adjust(query);
}
Expression::Literal(_)
| Expression::FunctionArgument(_)
| Expression::GlobalVariable(_)
| Expression::LocalVariable(_)
| Expression::CallResult(_)
| Expression::RayQueryProceedResult
| Expression::Constant(_)
| Expression::Override(_)
| Expression::ZeroValue(_)
| Expression::AtomicResult { .. }
| Expression::WorkGroupUniformLoadResult { .. } => {}
}
}
fn map_value_to_literal(value: f64, scalar: Scalar) -> Result<Literal, PipelineConstantError> {

View File

@@ -71,7 +71,7 @@ pub enum Error {
#[error("module is not validated properly: {0}")]
Validation(&'static str),
#[error(transparent)]
PipelineConstant(#[from] crate::back::pipeline_constants::PipelineConstantError),
PipelineConstant(#[from] Box<crate::back::pipeline_constants::PipelineConstantError>),
}
#[derive(Default)]
@@ -529,6 +529,42 @@ struct FunctionArgument {
handle_id: Word,
}
/// Tracks the expressions for which the backend emits the following instructions:
/// - OpConstantTrue
/// - OpConstantFalse
/// - OpConstant
/// - OpConstantComposite
/// - OpConstantNull
struct ExpressionConstnessTracker {
inner: bit_set::BitSet,
}
impl ExpressionConstnessTracker {
fn from_arena(arena: &crate::Arena<crate::Expression>) -> Self {
let mut inner = bit_set::BitSet::new();
for (handle, expr) in arena.iter() {
let insert = match *expr {
crate::Expression::Literal(_)
| crate::Expression::ZeroValue(_)
| crate::Expression::Constant(_) => true,
crate::Expression::Compose { ref components, .. } => {
components.iter().all(|h| inner.contains(h.index()))
}
crate::Expression::Splat { value, .. } => inner.contains(value.index()),
_ => false,
};
if insert {
inner.insert(handle.index());
}
}
Self { inner }
}
fn is_const(&self, value: Handle<crate::Expression>) -> bool {
self.inner.contains(value.index())
}
}
/// General information needed to emit SPIR-V for Naga statements.
struct BlockContext<'w> {
/// The writer handling the module to which this code belongs.
@@ -554,7 +590,7 @@ struct BlockContext<'w> {
temp_list: Vec<Word>,
/// Tracks the constness of `Expression`s residing in `self.ir_function.expressions`
expression_constness: crate::proc::ExpressionConstnessTracker,
expression_constness: ExpressionConstnessTracker,
}
impl BlockContext<'_> {

View File

@@ -615,7 +615,7 @@ impl Writer {
// Steal the Writer's temp list for a bit.
temp_list: std::mem::take(&mut self.temp_list),
writer: self,
expression_constness: crate::proc::ExpressionConstnessTracker::from_arena(
expression_constness: super::ExpressionConstnessTracker::from_arena(
&ir_function.expressions,
),
};
@@ -2029,15 +2029,21 @@ impl Writer {
debug_info: &Option<DebugInfo>,
words: &mut Vec<Word>,
) -> Result<(), Error> {
let ir_module = if let Some(pipeline_options) = pipeline_options {
let (ir_module, info) = if let Some(pipeline_options) = pipeline_options {
crate::back::pipeline_constants::process_overrides(
ir_module,
info,
&pipeline_options.constants,
)?
)
.map_err(Box::new)?
} else {
std::borrow::Cow::Borrowed(ir_module)
(
std::borrow::Cow::Borrowed(ir_module),
std::borrow::Cow::Borrowed(info),
)
};
let ir_module = ir_module.as_ref();
let info = info.as_ref();
self.reset();

View File

@@ -77,12 +77,19 @@ pub struct Context<'a> {
pub body: Block,
pub module: &'a mut crate::Module,
pub is_const: bool,
/// Tracks the constness of `Expression`s residing in `self.expressions`
pub expression_constness: crate::proc::ExpressionConstnessTracker,
/// Tracks the expression kind of `Expression`s residing in `self.expressions`
pub local_expression_kind_tracker: crate::proc::ExpressionConstnessTracker,
/// Tracks the expression kind of `Expression`s residing in `self.module.const_expressions`
pub global_expression_kind_tracker: &'a mut crate::proc::ExpressionConstnessTracker,
}
impl<'a> Context<'a> {
pub fn new(frontend: &Frontend, module: &'a mut crate::Module, is_const: bool) -> Result<Self> {
pub fn new(
frontend: &Frontend,
module: &'a mut crate::Module,
is_const: bool,
global_expression_kind_tracker: &'a mut crate::proc::ExpressionConstnessTracker,
) -> Result<Self> {
let mut this = Context {
expressions: Arena::new(),
locals: Arena::new(),
@@ -101,7 +108,8 @@ impl<'a> Context<'a> {
body: Block::new(),
module,
is_const: false,
expression_constness: crate::proc::ExpressionConstnessTracker::new(),
local_expression_kind_tracker: crate::proc::ExpressionConstnessTracker::new(),
global_expression_kind_tracker,
};
this.emit_start();
@@ -249,12 +257,15 @@ impl<'a> Context<'a> {
pub fn add_expression(&mut self, expr: Expression, meta: Span) -> Result<Handle<Expression>> {
let mut eval = if self.is_const {
crate::proc::ConstantEvaluator::for_glsl_module(self.module)
crate::proc::ConstantEvaluator::for_glsl_module(
self.module,
self.global_expression_kind_tracker,
)
} else {
crate::proc::ConstantEvaluator::for_glsl_function(
self.module,
&mut self.expressions,
&mut self.expression_constness,
&mut self.local_expression_kind_tracker,
&mut self.emitter,
&mut self.body,
)

View File

@@ -1236,6 +1236,8 @@ impl Frontend {
let pointer = ctx
.expressions
.append(Expression::GlobalVariable(arg.handle), Default::default());
ctx.local_expression_kind_tracker
.insert(pointer, crate::proc::ExpressionKind::Runtime);
let ty = ctx.module.global_variables[arg.handle].ty;
@@ -1256,6 +1258,8 @@ impl Frontend {
let value = ctx
.expressions
.append(Expression::FunctionArgument(idx), Default::default());
ctx.local_expression_kind_tracker
.insert(value, crate::proc::ExpressionKind::Runtime);
ctx.body
.push(Statement::Store { pointer, value }, Default::default());
},
@@ -1285,6 +1289,8 @@ impl Frontend {
let pointer = ctx
.expressions
.append(Expression::GlobalVariable(arg.handle), Default::default());
ctx.local_expression_kind_tracker
.insert(pointer, crate::proc::ExpressionKind::Runtime);
let ty = ctx.module.global_variables[arg.handle].ty;
@@ -1307,6 +1313,8 @@ impl Frontend {
let load = ctx
.expressions
.append(Expression::Load { pointer }, Default::default());
ctx.local_expression_kind_tracker
.insert(load, crate::proc::ExpressionKind::Runtime);
ctx.body.push(
Statement::Emit(ctx.expressions.range_from(len)),
Default::default(),
@@ -1329,6 +1337,8 @@ impl Frontend {
let res = ctx
.expressions
.append(Expression::Compose { ty, components }, Default::default());
ctx.local_expression_kind_tracker
.insert(res, crate::proc::ExpressionKind::Runtime);
ctx.body.push(
Statement::Emit(ctx.expressions.range_from(len)),
Default::default(),

View File

@@ -164,9 +164,15 @@ impl<'source> ParsingContext<'source> {
pub fn parse(&mut self, frontend: &mut Frontend) -> Result<Module> {
let mut module = Module::default();
let mut global_expression_kind_tracker = crate::proc::ExpressionConstnessTracker::new();
// Body and expression arena for global initialization
let mut ctx = Context::new(frontend, &mut module, false)?;
let mut ctx = Context::new(
frontend,
&mut module,
false,
&mut global_expression_kind_tracker,
)?;
while self.peek(frontend).is_some() {
self.parse_external_declaration(frontend, &mut ctx)?;
@@ -196,7 +202,11 @@ impl<'source> ParsingContext<'source> {
frontend: &mut Frontend,
ctx: &mut Context,
) -> Result<(u32, Span)> {
let (const_expr, meta) = self.parse_constant_expression(frontend, ctx.module)?;
let (const_expr, meta) = self.parse_constant_expression(
frontend,
ctx.module,
ctx.global_expression_kind_tracker,
)?;
let res = ctx.module.to_ctx().eval_expr_to_u32(const_expr);
@@ -219,8 +229,9 @@ impl<'source> ParsingContext<'source> {
&mut self,
frontend: &mut Frontend,
module: &mut Module,
global_expression_kind_tracker: &mut crate::proc::ExpressionConstnessTracker,
) -> Result<(Handle<Expression>, Span)> {
let mut ctx = Context::new(frontend, module, true)?;
let mut ctx = Context::new(frontend, module, true, global_expression_kind_tracker)?;
let mut stmt_ctx = ctx.stmt_ctx();
let expr = self.parse_conditional(frontend, &mut ctx, &mut stmt_ctx, None)?;

View File

@@ -251,7 +251,7 @@ impl<'source> ParsingContext<'source> {
init.and_then(|expr| ctx.ctx.lift_up_const_expression(expr).ok());
late_initializer = None;
} else if let Some(init) = init {
if ctx.is_inside_loop || !ctx.ctx.expression_constness.is_const(init) {
if ctx.is_inside_loop || !ctx.ctx.local_expression_kind_tracker.is_const(init) {
decl_initializer = None;
late_initializer = Some(init);
} else {
@@ -326,7 +326,12 @@ impl<'source> ParsingContext<'source> {
let result = ty.map(|ty| FunctionResult { ty, binding: None });
let mut context = Context::new(frontend, ctx.module, false)?;
let mut context = Context::new(
frontend,
ctx.module,
false,
ctx.global_expression_kind_tracker,
)?;
self.parse_function_args(frontend, &mut context)?;

View File

@@ -192,8 +192,11 @@ impl<'source> ParsingContext<'source> {
TokenValue::Case => {
self.bump(frontend)?;
let (const_expr, meta) =
self.parse_constant_expression(frontend, ctx.module)?;
let (const_expr, meta) = self.parse_constant_expression(
frontend,
ctx.module,
ctx.global_expression_kind_tracker,
)?;
match ctx.module.const_expressions[const_expr] {
Expression::Literal(Literal::I32(value)) => match uint {

View File

@@ -330,7 +330,7 @@ impl Context<'_> {
expr: Handle<Expression>,
) -> Result<Handle<Expression>> {
let meta = self.expressions.get_span(expr);
Ok(match self.expressions[expr] {
let h = match self.expressions[expr] {
ref expr @ (Expression::Literal(_)
| Expression::Constant(_)
| Expression::ZeroValue(_)) => self.module.const_expressions.append(expr.clone(), meta),
@@ -355,6 +355,9 @@ impl Context<'_> {
meta,
})
}
})
};
self.global_expression_kind_tracker
.insert(h, crate::proc::ExpressionKind::Const);
Ok(h)
}
}

View File

@@ -86,6 +86,8 @@ pub struct GlobalContext<'source, 'temp, 'out> {
module: &'out mut crate::Module,
const_typifier: &'temp mut Typifier,
global_expression_kind_tracker: &'temp mut crate::proc::ExpressionConstnessTracker,
}
impl<'source> GlobalContext<'source, '_, '_> {
@@ -97,6 +99,19 @@ impl<'source> GlobalContext<'source, '_, '_> {
module: self.module,
const_typifier: self.const_typifier,
expr_type: ExpressionContextType::Constant,
global_expression_kind_tracker: self.global_expression_kind_tracker,
}
}
fn as_override(&mut self) -> ExpressionContext<'source, '_, '_> {
ExpressionContext {
ast_expressions: self.ast_expressions,
globals: self.globals,
types: self.types,
module: self.module,
const_typifier: self.const_typifier,
expr_type: ExpressionContextType::Override,
global_expression_kind_tracker: self.global_expression_kind_tracker,
}
}
@@ -165,6 +180,7 @@ pub struct StatementContext<'source, 'temp, 'out> {
/// we should consider them to be const. See the use of `force_non_const` in
/// the code for lowering `let` bindings.
expression_constness: &'temp mut crate::proc::ExpressionConstnessTracker,
global_expression_kind_tracker: &'temp mut crate::proc::ExpressionConstnessTracker,
}
impl<'a, 'temp> StatementContext<'a, 'temp, '_> {
@@ -181,6 +197,7 @@ impl<'a, 'temp> StatementContext<'a, 'temp, '_> {
types: self.types,
ast_expressions: self.ast_expressions,
const_typifier: self.const_typifier,
global_expression_kind_tracker: self.global_expression_kind_tracker,
module: self.module,
expr_type: ExpressionContextType::Runtime(RuntimeExpressionContext {
local_table: self.local_table,
@@ -200,6 +217,7 @@ impl<'a, 'temp> StatementContext<'a, 'temp, '_> {
types: self.types,
module: self.module,
const_typifier: self.const_typifier,
global_expression_kind_tracker: self.global_expression_kind_tracker,
}
}
@@ -253,6 +271,14 @@ pub enum ExpressionContextType<'temp, 'out> {
/// available in the [`ExpressionContext`], so this variant
/// carries no further information.
Constant,
/// We are lowering to an override expression, to be included in the module's
/// constant expression arena.
///
/// Everything override expressions are allowed to refer to is
/// available in the [`ExpressionContext`], so this variant
/// carries no further information.
Override,
}
/// State for lowering an [`ast::Expression`] to Naga IR.
@@ -311,6 +337,7 @@ pub struct ExpressionContext<'source, 'temp, 'out> {
///
/// [`module::const_expressions`]: crate::Module::const_expressions
const_typifier: &'temp mut Typifier,
global_expression_kind_tracker: &'temp mut crate::proc::ExpressionConstnessTracker,
/// Whether we are lowering a constant expression or a general
/// runtime expression, and the data needed in each case.
@@ -326,6 +353,7 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
const_typifier: self.const_typifier,
module: self.module,
expr_type: ExpressionContextType::Constant,
global_expression_kind_tracker: self.global_expression_kind_tracker,
}
}
@@ -336,6 +364,7 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
types: self.types,
module: self.module,
const_typifier: self.const_typifier,
global_expression_kind_tracker: self.global_expression_kind_tracker,
}
}
@@ -348,7 +377,16 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
rctx.emitter,
rctx.block,
),
ExpressionContextType::Constant => ConstantEvaluator::for_wgsl_module(self.module),
ExpressionContextType::Constant => ConstantEvaluator::for_wgsl_module(
self.module,
self.global_expression_kind_tracker,
false,
),
ExpressionContextType::Override => ConstantEvaluator::for_wgsl_module(
self.module,
self.global_expression_kind_tracker,
true,
),
}
}
@@ -375,20 +413,25 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
.ok()
}
ExpressionContextType::Constant => self.module.to_ctx().eval_expr_to_u32(handle).ok(),
ExpressionContextType::Override => None,
}
}
fn get_expression_span(&self, handle: Handle<crate::Expression>) -> Span {
match self.expr_type {
ExpressionContextType::Runtime(ref ctx) => ctx.function.expressions.get_span(handle),
ExpressionContextType::Constant => self.module.const_expressions.get_span(handle),
ExpressionContextType::Constant | ExpressionContextType::Override => {
self.module.const_expressions.get_span(handle)
}
}
}
fn typifier(&self) -> &Typifier {
match self.expr_type {
ExpressionContextType::Runtime(ref ctx) => ctx.typifier,
ExpressionContextType::Constant => self.const_typifier,
ExpressionContextType::Constant | ExpressionContextType::Override => {
self.const_typifier
}
}
}
@@ -398,7 +441,9 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
) -> Result<&mut RuntimeExpressionContext<'temp, 'out>, Error<'source>> {
match self.expr_type {
ExpressionContextType::Runtime(ref mut ctx) => Ok(ctx),
ExpressionContextType::Constant => Err(Error::UnexpectedOperationInConstContext(span)),
ExpressionContextType::Constant | ExpressionContextType::Override => {
Err(Error::UnexpectedOperationInConstContext(span))
}
}
}
@@ -435,7 +480,7 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
}
// This means a `gather` operation appeared in a constant expression.
// This error refers to the `gather` itself, not its "component" argument.
ExpressionContextType::Constant => {
ExpressionContextType::Constant | ExpressionContextType::Override => {
Err(Error::UnexpectedOperationInConstContext(gather_span))
}
}
@@ -461,7 +506,9 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
// to also borrow self.module.types mutably below.
let typifier = match self.expr_type {
ExpressionContextType::Runtime(ref ctx) => ctx.typifier,
ExpressionContextType::Constant => &*self.const_typifier,
ExpressionContextType::Constant | ExpressionContextType::Override => {
&*self.const_typifier
}
};
Ok(typifier.register_type(handle, &mut self.module.types))
}
@@ -504,7 +551,7 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
typifier = &mut *ctx.typifier;
expressions = &ctx.function.expressions;
}
ExpressionContextType::Constant => {
ExpressionContextType::Constant | ExpressionContextType::Override => {
resolve_ctx = ResolveContext::with_locals(self.module, &empty_arena, &[]);
typifier = self.const_typifier;
expressions = &self.module.const_expressions;
@@ -600,14 +647,14 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
rctx.block
.extend(rctx.emitter.finish(&rctx.function.expressions));
}
ExpressionContextType::Constant => {}
ExpressionContextType::Constant | ExpressionContextType::Override => {}
}
let result = self.append_expression(expression, span);
match self.expr_type {
ExpressionContextType::Runtime(ref mut rctx) => {
rctx.emitter.start(&rctx.function.expressions);
}
ExpressionContextType::Constant => {}
ExpressionContextType::Constant | ExpressionContextType::Override => {}
}
result
}
@@ -852,6 +899,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
types: &tu.types,
module: &mut module,
const_typifier: &mut Typifier::new(),
global_expression_kind_tracker: &mut crate::proc::ExpressionConstnessTracker::new(),
};
for decl_handle in self.index.visit_ordered() {
@@ -959,7 +1007,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
ast::GlobalDeclKind::Override(ref o) => {
let init = o
.init
.map(|init| self.expression(init, &mut ctx.as_const()))
.map(|init| self.expression(init, &mut ctx.as_override()))
.transpose()?;
let inferred_type = init
.map(|init| ctx.as_const().register_type(init))
@@ -1049,6 +1097,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
let mut local_table = FastHashMap::default();
let mut expressions = Arena::new();
let mut named_expressions = FastIndexMap::default();
let mut local_expression_kind_tracker = crate::proc::ExpressionConstnessTracker::new();
let arguments = f
.arguments
@@ -1060,6 +1109,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
.append(crate::Expression::FunctionArgument(i as u32), arg.name.span);
local_table.insert(arg.handle, Typed::Plain(expr));
named_expressions.insert(expr, (arg.name.name.to_string(), arg.name.span));
local_expression_kind_tracker.insert(expr, crate::proc::ExpressionKind::Runtime);
Ok(crate::FunctionArgument {
name: Some(arg.name.name.to_string()),
@@ -1102,7 +1152,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
named_expressions: &mut named_expressions,
types: ctx.types,
module: ctx.module,
expression_constness: &mut crate::proc::ExpressionConstnessTracker::new(),
expression_constness: &mut local_expression_kind_tracker,
global_expression_kind_tracker: ctx.global_expression_kind_tracker,
};
let mut body = self.block(&f.body, false, &mut stmt_ctx)?;
ensure_block_returns(&mut body);
@@ -1518,6 +1569,10 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
.function
.expressions
.append(crate::Expression::Binary { op, left, right }, stmt.span);
rctx.expression_constness
.insert(left, crate::proc::ExpressionKind::Runtime);
rctx.expression_constness
.insert(value, crate::proc::ExpressionKind::Runtime);
block.extend(emitter.finish(&ctx.function.expressions));
crate::Statement::Store {
@@ -1611,7 +1666,12 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
LoweredGlobalDecl::Const(handle) => {
Typed::Plain(crate::Expression::Constant(handle))
}
_ => {
LoweredGlobalDecl::Override(handle) => {
Typed::Plain(crate::Expression::Override(handle))
}
LoweredGlobalDecl::Function(_)
| LoweredGlobalDecl::Type(_)
| LoweredGlobalDecl::EntryPoint => {
return Err(Error::Unexpected(span, ExpectedToken::Variable));
}
};
@@ -1886,9 +1946,13 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
rctx.block
.extend(rctx.emitter.finish(&rctx.function.expressions));
let result = has_result.then(|| {
rctx.function
let result = rctx
.function
.expressions
.append(crate::Expression::CallResult(function), span)
.append(crate::Expression::CallResult(function), span);
rctx.expression_constness
.insert(result, crate::proc::ExpressionKind::Runtime);
result
});
rctx.emitter.start(&rctx.function.expressions);
rctx.block.push(

View File

@@ -253,9 +253,9 @@ gen_component_wise_extractor! {
}
#[derive(Debug)]
enum Behavior {
Wgsl,
Glsl,
enum Behavior<'a> {
Wgsl(WgslRestrictions<'a>),
Glsl(GlslRestrictions<'a>),
}
/// A context for evaluating constant expressions.
@@ -278,7 +278,7 @@ enum Behavior {
#[derive(Debug)]
pub struct ConstantEvaluator<'a> {
/// Which language's evaluation rules we should follow.
behavior: Behavior,
behavior: Behavior<'a>,
/// The module's type arena.
///
@@ -297,65 +297,145 @@ pub struct ConstantEvaluator<'a> {
/// The arena to which we are contributing expressions.
expressions: &'a mut Arena<Expression>,
/// When `self.expressions` refers to a function's local expression
/// arena, this needs to be populated
function_local_data: Option<FunctionLocalData<'a>>,
/// Tracks the constness of expressions residing in [`Self::expressions`]
expression_kind_tracker: &'a mut ExpressionConstnessTracker,
}
#[derive(Debug)]
enum WgslRestrictions<'a> {
/// - const-expressions will be evaluated and inserted in the arena
Const,
/// - const-expressions will be evaluated and inserted in the arena
/// - override-expressions will be inserted in the arena
Override,
/// - const-expressions will be evaluated and inserted in the arena
/// - override-expressions will be inserted in the arena
/// - runtime-expressions will be inserted in the arena
Runtime(FunctionLocalData<'a>),
}
#[derive(Debug)]
enum GlslRestrictions<'a> {
/// - const-expressions will be evaluated and inserted in the arena
Const,
/// - const-expressions will be evaluated and inserted in the arena
/// - override-expressions will be inserted in the arena
/// - runtime-expressions will be inserted in the arena
Runtime(FunctionLocalData<'a>),
}
#[derive(Debug)]
struct FunctionLocalData<'a> {
/// Global constant expressions
const_expressions: &'a Arena<Expression>,
/// Tracks the constness of expressions residing in `ConstantEvaluator.expressions`
expression_constness: &'a mut ExpressionConstnessTracker,
emitter: &'a mut super::Emitter,
block: &'a mut crate::Block,
}
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
pub enum ExpressionKind {
Const,
Override,
Runtime,
}
#[derive(Debug)]
pub struct ExpressionConstnessTracker {
inner: bit_set::BitSet,
inner: Vec<ExpressionKind>,
}
impl ExpressionConstnessTracker {
pub fn new() -> Self {
Self {
inner: bit_set::BitSet::new(),
}
pub const fn new() -> Self {
Self { inner: Vec::new() }
}
/// Forces the the expression to not be const
pub fn force_non_const(&mut self, value: Handle<Expression>) {
self.inner.remove(value.index());
self.inner[value.index()] = ExpressionKind::Runtime;
}
fn insert(&mut self, value: Handle<Expression>) {
self.inner.insert(value.index());
pub fn insert(&mut self, value: Handle<Expression>, expr_type: ExpressionKind) {
assert_eq!(self.inner.len(), value.index());
self.inner.push(expr_type);
}
pub fn is_const(&self, h: Handle<Expression>) -> bool {
matches!(self.type_of(h), ExpressionKind::Const)
}
pub fn is_const(&self, value: Handle<Expression>) -> bool {
self.inner.contains(value.index())
pub fn is_const_or_override(&self, h: Handle<Expression>) -> bool {
matches!(
self.type_of(h),
ExpressionKind::Const | ExpressionKind::Override
)
}
fn type_of(&self, value: Handle<Expression>) -> ExpressionKind {
self.inner[value.index()]
}
pub fn from_arena(arena: &Arena<Expression>) -> Self {
let mut tracker = Self::new();
for (handle, expr) in arena.iter() {
let insert = match *expr {
crate::Expression::Literal(_)
| crate::Expression::ZeroValue(_)
| crate::Expression::Constant(_) => true,
crate::Expression::Compose { ref components, .. } => {
components.iter().all(|h| tracker.is_const(*h))
}
crate::Expression::Splat { value, .. } => tracker.is_const(value),
_ => false,
};
if insert {
tracker.insert(handle);
}
let mut tracker = Self {
inner: Vec::with_capacity(arena.len()),
};
for (_, expr) in arena.iter() {
tracker.inner.push(tracker.type_of_with_expr(expr));
}
tracker
}
fn type_of_with_expr(&self, expr: &Expression) -> ExpressionKind {
match *expr {
Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => {
ExpressionKind::Const
}
Expression::Override(_) => ExpressionKind::Override,
Expression::Compose { ref components, .. } => {
let mut expr_type = ExpressionKind::Const;
for component in components {
expr_type = expr_type.max(self.type_of(*component))
}
expr_type
}
Expression::Splat { value, .. } => self.type_of(value),
Expression::AccessIndex { base, .. } => self.type_of(base),
Expression::Access { base, index } => self.type_of(base).max(self.type_of(index)),
Expression::Swizzle { vector, .. } => self.type_of(vector),
Expression::Unary { expr, .. } => self.type_of(expr),
Expression::Binary { left, right, .. } => self.type_of(left).max(self.type_of(right)),
Expression::Math {
arg,
arg1,
arg2,
arg3,
..
} => self
.type_of(arg)
.max(
arg1.map(|arg| self.type_of(arg))
.unwrap_or(ExpressionKind::Const),
)
.max(
arg2.map(|arg| self.type_of(arg))
.unwrap_or(ExpressionKind::Const),
)
.max(
arg3.map(|arg| self.type_of(arg))
.unwrap_or(ExpressionKind::Const),
),
Expression::As { expr, .. } => self.type_of(expr),
Expression::Select {
condition,
accept,
reject,
} => self
.type_of(condition)
.max(self.type_of(accept))
.max(self.type_of(reject)),
Expression::Relational { argument, .. } => self.type_of(argument),
Expression::ArrayLength(expr) => self.type_of(expr),
_ => ExpressionKind::Runtime,
}
}
}
#[derive(Clone, Debug, thiserror::Error)]
@@ -436,6 +516,12 @@ pub enum ConstantEvaluatorError {
ShiftedMoreThan32Bits,
#[error(transparent)]
Literal(#[from] crate::valid::LiteralError),
#[error("Can't use pipeline-overridable constants in const-expressions")]
Override,
#[error("Unexpected runtime-expression")]
RuntimeExpr,
#[error("Unexpected override-expression")]
OverrideExpr,
}
impl<'a> ConstantEvaluator<'a> {
@@ -443,26 +529,49 @@ impl<'a> ConstantEvaluator<'a> {
/// constant expression arena.
///
/// Report errors according to WGSL's rules for constant evaluation.
pub fn for_wgsl_module(module: &'a mut crate::Module) -> Self {
Self::for_module(Behavior::Wgsl, module)
pub fn for_wgsl_module(
module: &'a mut crate::Module,
global_expression_kind_tracker: &'a mut ExpressionConstnessTracker,
in_override_ctx: bool,
) -> Self {
Self::for_module(
Behavior::Wgsl(if in_override_ctx {
WgslRestrictions::Override
} else {
WgslRestrictions::Const
}),
module,
global_expression_kind_tracker,
)
}
/// Return a [`ConstantEvaluator`] that will add expressions to `module`'s
/// constant expression arena.
///
/// Report errors according to GLSL's rules for constant evaluation.
pub fn for_glsl_module(module: &'a mut crate::Module) -> Self {
Self::for_module(Behavior::Glsl, module)
pub fn for_glsl_module(
module: &'a mut crate::Module,
global_expression_kind_tracker: &'a mut ExpressionConstnessTracker,
) -> Self {
Self::for_module(
Behavior::Glsl(GlslRestrictions::Const),
module,
global_expression_kind_tracker,
)
}
fn for_module(behavior: Behavior, module: &'a mut crate::Module) -> Self {
fn for_module(
behavior: Behavior<'a>,
module: &'a mut crate::Module,
global_expression_kind_tracker: &'a mut ExpressionConstnessTracker,
) -> Self {
Self {
behavior,
types: &mut module.types,
constants: &module.constants,
overrides: &module.overrides,
expressions: &mut module.const_expressions,
function_local_data: None,
expression_kind_tracker: global_expression_kind_tracker,
}
}
@@ -473,18 +582,22 @@ impl<'a> ConstantEvaluator<'a> {
pub fn for_wgsl_function(
module: &'a mut crate::Module,
expressions: &'a mut Arena<Expression>,
expression_constness: &'a mut ExpressionConstnessTracker,
local_expression_kind_tracker: &'a mut ExpressionConstnessTracker,
emitter: &'a mut super::Emitter,
block: &'a mut crate::Block,
) -> Self {
Self::for_function(
Behavior::Wgsl,
module,
Self {
behavior: Behavior::Wgsl(WgslRestrictions::Runtime(FunctionLocalData {
const_expressions: &module.const_expressions,
emitter,
block,
})),
types: &mut module.types,
constants: &module.constants,
overrides: &module.overrides,
expressions,
expression_constness,
emitter,
block,
)
expression_kind_tracker: local_expression_kind_tracker,
}
}
/// Return a [`ConstantEvaluator`] that will add expressions to `function`'s
@@ -494,40 +607,21 @@ impl<'a> ConstantEvaluator<'a> {
pub fn for_glsl_function(
module: &'a mut crate::Module,
expressions: &'a mut Arena<Expression>,
expression_constness: &'a mut ExpressionConstnessTracker,
emitter: &'a mut super::Emitter,
block: &'a mut crate::Block,
) -> Self {
Self::for_function(
Behavior::Glsl,
module,
expressions,
expression_constness,
emitter,
block,
)
}
fn for_function(
behavior: Behavior,
module: &'a mut crate::Module,
expressions: &'a mut Arena<Expression>,
expression_constness: &'a mut ExpressionConstnessTracker,
local_expression_kind_tracker: &'a mut ExpressionConstnessTracker,
emitter: &'a mut super::Emitter,
block: &'a mut crate::Block,
) -> Self {
Self {
behavior,
behavior: Behavior::Glsl(GlslRestrictions::Runtime(FunctionLocalData {
const_expressions: &module.const_expressions,
emitter,
block,
})),
types: &mut module.types,
constants: &module.constants,
overrides: &module.overrides,
expressions,
function_local_data: Some(FunctionLocalData {
const_expressions: &module.const_expressions,
expression_constness,
emitter,
block,
}),
expression_kind_tracker: local_expression_kind_tracker,
}
}
@@ -536,19 +630,17 @@ impl<'a> ConstantEvaluator<'a> {
types: self.types,
constants: self.constants,
overrides: self.overrides,
const_expressions: match self.function_local_data {
Some(ref data) => data.const_expressions,
const_expressions: match self.function_local_data() {
Some(data) => data.const_expressions,
None => self.expressions,
},
}
}
fn check(&self, expr: Handle<Expression>) -> Result<(), ConstantEvaluatorError> {
if let Some(ref function_local_data) = self.function_local_data {
if !function_local_data.expression_constness.is_const(expr) {
log::debug!("check: SubexpressionsAreNotConstant");
return Err(ConstantEvaluatorError::SubexpressionsAreNotConstant);
}
if !self.expression_kind_tracker.is_const(expr) {
log::debug!("check: SubexpressionsAreNotConstant");
return Err(ConstantEvaluatorError::SubexpressionsAreNotConstant);
}
Ok(())
}
@@ -561,7 +653,7 @@ impl<'a> ConstantEvaluator<'a> {
Expression::Constant(c) => {
// Are we working in a function's expression arena, or the
// module's constant expression arena?
if let Some(ref function_local_data) = self.function_local_data {
if let Some(function_local_data) = self.function_local_data() {
// Deep-copy the constant's value into our arena.
self.copy_from(
self.constants[c].init,
@@ -607,14 +699,56 @@ impl<'a> ConstantEvaluator<'a> {
expr: Expression,
span: Span,
) -> Result<Handle<Expression>, ConstantEvaluatorError> {
let res = self.try_eval_and_append_impl(&expr, span);
if self.function_local_data.is_some() {
match res {
Ok(h) => Ok(h),
Err(_) => Ok(self.append_expr(expr, span, false)),
match (
&self.behavior,
self.expression_kind_tracker.type_of_with_expr(&expr),
) {
// avoid errors on unimplemented functionality if possible
(
&Behavior::Wgsl(WgslRestrictions::Runtime(_))
| &Behavior::Glsl(GlslRestrictions::Runtime(_)),
ExpressionKind::Const,
) => match self.try_eval_and_append_impl(&expr, span) {
Err(
ConstantEvaluatorError::NotImplemented(_)
| ConstantEvaluatorError::InvalidBinaryOpArgs,
) => Ok(self.append_expr(expr, span, ExpressionKind::Runtime)),
res => res,
},
(_, ExpressionKind::Const) => self.try_eval_and_append_impl(&expr, span),
(&Behavior::Wgsl(WgslRestrictions::Const), ExpressionKind::Override) => {
Err(ConstantEvaluatorError::OverrideExpr)
}
} else {
res
(
&Behavior::Wgsl(WgslRestrictions::Override | WgslRestrictions::Runtime(_)),
ExpressionKind::Override,
) => Ok(self.append_expr(expr, span, ExpressionKind::Override)),
(&Behavior::Glsl(_), ExpressionKind::Override) => unreachable!(),
(
&Behavior::Wgsl(WgslRestrictions::Runtime(_))
| &Behavior::Glsl(GlslRestrictions::Runtime(_)),
ExpressionKind::Runtime,
) => Ok(self.append_expr(expr, span, ExpressionKind::Runtime)),
(_, ExpressionKind::Runtime) => Err(ConstantEvaluatorError::RuntimeExpr),
}
}
/// Is the [`Self::expressions`] arena the global module expression arena?
const fn is_global_arena(&self) -> bool {
matches!(
self.behavior,
Behavior::Wgsl(WgslRestrictions::Const | WgslRestrictions::Override)
| Behavior::Glsl(GlslRestrictions::Const)
)
}
const fn function_local_data(&self) -> Option<&FunctionLocalData<'a>> {
match self.behavior {
Behavior::Wgsl(WgslRestrictions::Runtime(ref function_local_data))
| Behavior::Glsl(GlslRestrictions::Runtime(ref function_local_data)) => {
Some(function_local_data)
}
_ => None,
}
}
@@ -625,14 +759,12 @@ impl<'a> ConstantEvaluator<'a> {
) -> Result<Handle<Expression>, ConstantEvaluatorError> {
log::trace!("try_eval_and_append: {:?}", expr);
match *expr {
Expression::Constant(c) if self.function_local_data.is_none() => {
Expression::Constant(c) if self.is_global_arena() => {
// "See through" the constant and use its initializer.
// This is mainly done to avoid having constants pointing to other constants.
Ok(self.constants[c].init)
}
Expression::Override(_) => Err(ConstantEvaluatorError::NotImplemented(
"overrides are WIP".into(),
)),
Expression::Override(_) => Err(ConstantEvaluatorError::Override),
Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => {
self.register_evaluated_expr(expr.clone(), span)
}
@@ -713,8 +845,8 @@ impl<'a> ConstantEvaluator<'a> {
format!("{fun:?} built-in function"),
)),
Expression::ArrayLength(expr) => match self.behavior {
Behavior::Wgsl => Err(ConstantEvaluatorError::ArrayLength),
Behavior::Glsl => {
Behavior::Wgsl(_) => Err(ConstantEvaluatorError::ArrayLength),
Behavior::Glsl(_) => {
let expr = self.check_and_get(expr)?;
self.array_length(expr, span)
}
@@ -1881,34 +2013,35 @@ impl<'a> ConstantEvaluator<'a> {
crate::valid::check_literal_value(literal)?;
}
Ok(self.append_expr(expr, span, true))
Ok(self.append_expr(expr, span, ExpressionKind::Const))
}
fn append_expr(&mut self, expr: Expression, span: Span, is_const: bool) -> Handle<Expression> {
if let Some(FunctionLocalData {
ref mut emitter,
ref mut block,
ref mut expression_constness,
..
}) = self.function_local_data
{
let is_running = emitter.is_running();
let needs_pre_emit = expr.needs_pre_emit();
let h = if is_running && needs_pre_emit {
block.extend(emitter.finish(self.expressions));
let h = self.expressions.append(expr, span);
emitter.start(self.expressions);
h
} else {
self.expressions.append(expr, span)
};
if is_const {
expression_constness.insert(h);
fn append_expr(
&mut self,
expr: Expression,
span: Span,
expr_type: ExpressionKind,
) -> Handle<Expression> {
let h = match self.behavior {
Behavior::Wgsl(WgslRestrictions::Runtime(ref mut function_local_data))
| Behavior::Glsl(GlslRestrictions::Runtime(ref mut function_local_data)) => {
let is_running = function_local_data.emitter.is_running();
let needs_pre_emit = expr.needs_pre_emit();
if is_running && needs_pre_emit {
function_local_data
.block
.extend(function_local_data.emitter.finish(self.expressions));
let h = self.expressions.append(expr, span);
function_local_data.emitter.start(self.expressions);
h
} else {
self.expressions.append(expr, span)
}
}
h
} else {
self.expressions.append(expr, span)
}
_ => self.expressions.append(expr, span),
};
self.expression_kind_tracker.insert(h, expr_type);
h
}
fn resolve_type(
@@ -2062,7 +2195,7 @@ mod tests {
UniqueArena, VectorSize,
};
use super::{Behavior, ConstantEvaluator};
use super::{Behavior, ConstantEvaluator, ExpressionConstnessTracker, WgslRestrictions};
#[test]
fn unary_op() {
@@ -2143,13 +2276,15 @@ mod tests {
expr: expr1,
};
let expression_kind_tracker =
&mut ExpressionConstnessTracker::from_arena(&const_expressions);
let mut solver = ConstantEvaluator {
behavior: Behavior::Wgsl,
behavior: Behavior::Wgsl(WgslRestrictions::Const),
types: &mut types,
constants: &constants,
overrides: &overrides,
expressions: &mut const_expressions,
function_local_data: None,
expression_kind_tracker,
};
let res1 = solver
@@ -2228,13 +2363,15 @@ mod tests {
convert: Some(crate::BOOL_WIDTH),
};
let expression_kind_tracker =
&mut ExpressionConstnessTracker::from_arena(&const_expressions);
let mut solver = ConstantEvaluator {
behavior: Behavior::Wgsl,
behavior: Behavior::Wgsl(WgslRestrictions::Const),
types: &mut types,
constants: &constants,
overrides: &overrides,
expressions: &mut const_expressions,
function_local_data: None,
expression_kind_tracker,
};
let res = solver
@@ -2345,13 +2482,15 @@ mod tests {
let base = const_expressions.append(Expression::Constant(h), Default::default());
let expression_kind_tracker =
&mut ExpressionConstnessTracker::from_arena(&const_expressions);
let mut solver = ConstantEvaluator {
behavior: Behavior::Wgsl,
behavior: Behavior::Wgsl(WgslRestrictions::Const),
types: &mut types,
constants: &constants,
overrides: &overrides,
expressions: &mut const_expressions,
function_local_data: None,
expression_kind_tracker,
};
let root1 = Expression::AccessIndex { base, index: 1 };
@@ -2437,13 +2576,15 @@ mod tests {
let h_expr = const_expressions.append(Expression::Constant(h), Default::default());
let expression_kind_tracker =
&mut ExpressionConstnessTracker::from_arena(&const_expressions);
let mut solver = ConstantEvaluator {
behavior: Behavior::Wgsl,
behavior: Behavior::Wgsl(WgslRestrictions::Const),
types: &mut types,
constants: &constants,
overrides: &overrides,
expressions: &mut const_expressions,
function_local_data: None,
expression_kind_tracker,
};
let solved_compose = solver
@@ -2518,13 +2659,15 @@ mod tests {
let h_expr = const_expressions.append(Expression::Constant(h), Default::default());
let expression_kind_tracker =
&mut ExpressionConstnessTracker::from_arena(&const_expressions);
let mut solver = ConstantEvaluator {
behavior: Behavior::Wgsl,
behavior: Behavior::Wgsl(WgslRestrictions::Const),
types: &mut types,
constants: &constants,
overrides: &overrides,
expressions: &mut const_expressions,
function_local_data: None,
expression_kind_tracker,
};
let solved_compose = solver

View File

@@ -11,7 +11,7 @@ mod terminator;
mod typifier;
pub use constant_evaluator::{
ConstantEvaluator, ConstantEvaluatorError, ExpressionConstnessTracker,
ConstantEvaluator, ConstantEvaluatorError, ExpressionConstnessTracker, ExpressionKind,
};
pub use emitter::Emitter;
pub use index::{BoundsCheckPolicies, BoundsCheckPolicy, IndexableLength, IndexableLengthError};

View File

@@ -226,7 +226,7 @@ struct Sampling {
sampler: GlobalOrArgument,
}
#[derive(Debug)]
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct FunctionInfo {

View File

@@ -90,6 +90,8 @@ pub enum ExpressionError {
sampler: bool,
has_ref: bool,
},
#[error("Sample offset must be a const-expression")]
InvalidSampleOffsetExprType,
#[error("Sample offset constant {1:?} doesn't match the image dimension {0:?}")]
InvalidSampleOffset(crate::ImageDimension, Handle<crate::Expression>),
#[error("Depth reference {0:?} is not a scalar float")]
@@ -129,9 +131,10 @@ pub enum ExpressionError {
}
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum ConstExpressionError {
#[error("The expression is not a constant expression")]
NonConst,
#[error("The expression is not a constant or override expression")]
NonConstOrOverride,
#[error(transparent)]
Compose(#[from] super::ComposeError),
#[error("Splatting {0:?} can't be done")]
@@ -184,9 +187,14 @@ impl super::Validator {
handle: Handle<crate::Expression>,
gctx: crate::proc::GlobalCtx,
mod_info: &ModuleInfo,
global_expr_kind: &crate::proc::ExpressionConstnessTracker,
) -> Result<(), ConstExpressionError> {
use crate::Expression as E;
if !global_expr_kind.is_const_or_override(handle) {
return Err(super::ConstExpressionError::NonConstOrOverride);
}
match gctx.const_expressions[handle] {
E::Literal(literal) => {
self.validate_literal(literal)?;
@@ -203,12 +211,14 @@ impl super::Validator {
crate::TypeInner::Scalar { .. } => {}
_ => return Err(super::ConstExpressionError::InvalidSplatType(value)),
},
_ => return Err(super::ConstExpressionError::NonConst),
// the constant evaluator will report errors about override-expressions
_ => {}
}
Ok(())
}
#[allow(clippy::too_many_arguments)]
pub(super) fn validate_expression(
&self,
root: Handle<crate::Expression>,
@@ -217,6 +227,7 @@ impl super::Validator {
module: &crate::Module,
info: &FunctionInfo,
mod_info: &ModuleInfo,
global_expr_kind: &crate::proc::ExpressionConstnessTracker,
) -> Result<ShaderStages, ExpressionError> {
use crate::{Expression as E, Scalar as Sc, ScalarKind as Sk, TypeInner as Ti};
@@ -462,6 +473,10 @@ impl super::Validator {
// check constant offset
if let Some(const_expr) = offset {
if !global_expr_kind.is_const(const_expr) {
return Err(ExpressionError::InvalidSampleOffsetExprType);
}
match *mod_info[const_expr].inner_with(&module.types) {
Ti::Scalar(Sc { kind: Sk::Sint, .. }) if num_components == 1 => {}
Ti::Vector {

View File

@@ -927,7 +927,7 @@ impl super::Validator {
var: &crate::LocalVariable,
gctx: crate::proc::GlobalCtx,
fun_info: &FunctionInfo,
expression_constness: &crate::proc::ExpressionConstnessTracker,
local_expr_kind: &crate::proc::ExpressionConstnessTracker,
) -> Result<(), LocalVariableError> {
log::debug!("var {:?}", var);
let type_info = self
@@ -945,7 +945,7 @@ impl super::Validator {
return Err(LocalVariableError::InitializerType);
}
if !expression_constness.is_const(init) {
if !local_expr_kind.is_const(init) {
return Err(LocalVariableError::NonConstInitializer);
}
}
@@ -959,14 +959,14 @@ impl super::Validator {
module: &crate::Module,
mod_info: &ModuleInfo,
entry_point: bool,
global_expr_kind: &crate::proc::ExpressionConstnessTracker,
) -> Result<FunctionInfo, WithSpan<FunctionError>> {
let mut info = mod_info.process_function(fun, module, self.flags, self.capabilities)?;
let expression_constness =
crate::proc::ExpressionConstnessTracker::from_arena(&fun.expressions);
let local_expr_kind = crate::proc::ExpressionConstnessTracker::from_arena(&fun.expressions);
for (var_handle, var) in fun.local_variables.iter() {
self.validate_local_var(var, module.to_ctx(), &info, &expression_constness)
self.validate_local_var(var, module.to_ctx(), &info, &local_expr_kind)
.map_err(|source| {
FunctionError::LocalVariable {
handle: var_handle,
@@ -1032,7 +1032,15 @@ impl super::Validator {
self.valid_expression_set.insert(handle.index());
}
if self.flags.contains(super::ValidationFlags::EXPRESSIONS) {
match self.validate_expression(handle, expr, fun, module, &info, mod_info) {
match self.validate_expression(
handle,
expr,
fun,
module,
&info,
mod_info,
global_expr_kind,
) {
Ok(stages) => info.available_stages &= stages,
Err(source) => {
return Err(FunctionError::Expression { handle, source }

View File

@@ -592,6 +592,7 @@ impl From<BadRangeError> for ValidationError {
}
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum InvalidHandleError {
#[error(transparent)]
BadHandle(#[from] BadHandle),
@@ -602,6 +603,7 @@ pub enum InvalidHandleError {
}
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
#[error(
"{subject:?} of kind {subject_kind:?} depends on {depends_on:?} of kind {depends_on_kind}, \
which has not been processed yet"

View File

@@ -10,6 +10,7 @@ use bit_set::BitSet;
const MAX_WORKGROUP_SIZE: u32 = 0x4000;
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum GlobalVariableError {
#[error("Usage isn't compatible with address space {0:?}")]
InvalidUsage(crate::AddressSpace),
@@ -30,6 +31,8 @@ pub enum GlobalVariableError {
Handle<crate::Type>,
#[source] Disalignment,
),
#[error("Initializer must be a const-expression")]
InitializerExprType,
#[error("Initializer doesn't match the variable type")]
InitializerType,
#[error("Initializer can't be used with address space {0:?}")]
@@ -39,6 +42,7 @@ pub enum GlobalVariableError {
}
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum VaryingError {
#[error("The type {0:?} does not match the varying")]
InvalidType(Handle<crate::Type>),
@@ -76,6 +80,7 @@ pub enum VaryingError {
}
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum EntryPointError {
#[error("Multiple conflicting entry points")]
Conflict,
@@ -395,6 +400,7 @@ impl super::Validator {
var: &crate::GlobalVariable,
gctx: crate::proc::GlobalCtx,
mod_info: &ModuleInfo,
global_expr_kind: &crate::proc::ExpressionConstnessTracker,
) -> Result<(), GlobalVariableError> {
use super::TypeFlags;
@@ -523,6 +529,10 @@ impl super::Validator {
}
}
if !global_expr_kind.is_const(init) {
return Err(GlobalVariableError::InitializerExprType);
}
let decl_ty = &gctx.types[var.ty].inner;
let init_ty = mod_info[init].inner_with(gctx.types);
if !decl_ty.equivalent(init_ty, gctx.types) {
@@ -538,6 +548,7 @@ impl super::Validator {
ep: &crate::EntryPoint,
module: &crate::Module,
mod_info: &ModuleInfo,
global_expr_kind: &crate::proc::ExpressionConstnessTracker,
) -> Result<FunctionInfo, WithSpan<EntryPointError>> {
if ep.early_depth_test.is_some() {
let required = Capabilities::EARLY_DEPTH_TEST;
@@ -566,7 +577,7 @@ impl super::Validator {
}
let mut info = self
.validate_function(&ep.function, module, mod_info, true)
.validate_function(&ep.function, module, mod_info, true, global_expr_kind)
.map_err(WithSpan::into_other)?;
{

View File

@@ -12,7 +12,7 @@ mod r#type;
use crate::{
arena::Handle,
proc::{LayoutError, Layouter, TypeResolution},
proc::{ExpressionConstnessTracker, LayoutError, Layouter, TypeResolution},
FastHashSet,
};
use bit_set::BitSet;
@@ -131,7 +131,7 @@ bitflags::bitflags! {
}
}
#[derive(Debug)]
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct ModuleInfo {
@@ -178,7 +178,10 @@ pub struct Validator {
}
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum ConstantError {
#[error("Initializer must be a const-expression")]
InitializerExprType,
#[error("The type doesn't match the constant")]
InvalidType,
#[error("The type is not constructible")]
@@ -186,11 +189,14 @@ pub enum ConstantError {
}
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum OverrideError {
#[error("Override name and ID are missing")]
MissingNameAndID,
#[error("Override ID must be unique")]
DuplicateID,
#[error("Initializer must be a const-expression or override-expression")]
InitializerExprType,
#[error("The type doesn't match the override")]
InvalidType,
#[error("The type is not constructible")]
@@ -200,6 +206,7 @@ pub enum OverrideError {
}
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum ValidationError {
#[error(transparent)]
InvalidHandle(#[from] InvalidHandleError),
@@ -335,6 +342,7 @@ impl Validator {
handle: Handle<crate::Constant>,
gctx: crate::proc::GlobalCtx,
mod_info: &ModuleInfo,
global_expr_kind: &ExpressionConstnessTracker,
) -> Result<(), ConstantError> {
let con = &gctx.constants[handle];
@@ -343,6 +351,10 @@ impl Validator {
return Err(ConstantError::NonConstructibleType);
}
if !global_expr_kind.is_const(con.init) {
return Err(ConstantError::InitializerExprType);
}
let decl_ty = &gctx.types[con.ty].inner;
let init_ty = mod_info[con.init].inner_with(gctx.types);
if !decl_ty.equivalent(init_ty, gctx.types) {
@@ -455,17 +467,24 @@ impl Validator {
}
}
let global_expr_kind = ExpressionConstnessTracker::from_arena(&module.const_expressions);
if self.flags.contains(ValidationFlags::CONSTANTS) {
for (handle, _) in module.const_expressions.iter() {
self.validate_const_expression(handle, module.to_ctx(), &mod_info)
.map_err(|source| {
ValidationError::ConstExpression { handle, source }
.with_span_handle(handle, &module.const_expressions)
})?
self.validate_const_expression(
handle,
module.to_ctx(),
&mod_info,
&global_expr_kind,
)
.map_err(|source| {
ValidationError::ConstExpression { handle, source }
.with_span_handle(handle, &module.const_expressions)
})?
}
for (handle, constant) in module.constants.iter() {
self.validate_constant(handle, module.to_ctx(), &mod_info)
self.validate_constant(handle, module.to_ctx(), &mod_info, &global_expr_kind)
.map_err(|source| {
ValidationError::Constant {
handle,
@@ -490,7 +509,7 @@ impl Validator {
}
for (var_handle, var) in module.global_variables.iter() {
self.validate_global_var(var, module.to_ctx(), &mod_info)
self.validate_global_var(var, module.to_ctx(), &mod_info, &global_expr_kind)
.map_err(|source| {
ValidationError::GlobalVariable {
handle: var_handle,
@@ -502,7 +521,7 @@ impl Validator {
}
for (handle, fun) in module.functions.iter() {
match self.validate_function(fun, module, &mod_info, false) {
match self.validate_function(fun, module, &mod_info, false, &global_expr_kind) {
Ok(info) => mod_info.functions.push(info),
Err(error) => {
return Err(error.and_then(|source| {
@@ -528,7 +547,7 @@ impl Validator {
.with_span()); // TODO: keep some EP span information?
}
match self.validate_entry_point(ep, module, &mod_info) {
match self.validate_entry_point(ep, module, &mod_info, &global_expr_kind) {
Ok(info) => mod_info.entry_points.push(info),
Err(error) => {
return Err(error.and_then(|source| {

View File

@@ -63,6 +63,7 @@ bitflags::bitflags! {
}
#[derive(Clone, Copy, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum Disalignment {
#[error("The array stride {stride} is not a multiple of the required alignment {alignment}")]
ArrayStride { stride: u32, alignment: Alignment },
@@ -87,6 +88,7 @@ pub enum Disalignment {
}
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum TypeError {
#[error("Capability {0:?} is required")]
MissingCapability(Capabilities),

View File

@@ -6,7 +6,7 @@
override depth: f32; // Specified at the API level using
// the name "depth".
// Must be overridden.
// override height = 2 * depth; // The default value
override height = 2 * depth; // The default value
// (if not set at the API level),
// depends on another
// overridable constant.

View File

@@ -33,6 +33,12 @@
kind: Float,
width: 4,
))),
Handle(2),
Value(Scalar((
kind: Float,
width: 4,
))),
Handle(2),
Value(Scalar((
kind: Float,
width: 4,

View File

@@ -3,6 +3,7 @@ static const float specular_param = 2.3;
static const float gain = 1.1;
static const float width = 0.0;
static const float depth = 2.3;
static const float height = 4.6;
static const float inferred_f32_ = 2.718;
[numthreads(1, 1, 1)]

View File

@@ -52,11 +52,17 @@
ty: 2,
init: None,
),
(
name: Some("height"),
id: None,
ty: 2,
init: Some(6),
),
(
name: Some("inferred_f32"),
id: None,
ty: 2,
init: Some(4),
init: Some(7),
),
],
global_variables: [],
@@ -64,6 +70,13 @@
Literal(Bool(true)),
Literal(F32(2.3)),
Literal(F32(0.0)),
Override(5),
Literal(F32(2.0)),
Binary(
op: Multiply,
left: 5,
right: 4,
),
Literal(F32(2.718)),
],
functions: [],

View File

@@ -52,11 +52,17 @@
ty: 2,
init: None,
),
(
name: Some("height"),
id: None,
ty: 2,
init: Some(6),
),
(
name: Some("inferred_f32"),
id: None,
ty: 2,
init: Some(4),
init: Some(7),
),
],
global_variables: [],
@@ -64,6 +70,13 @@
Literal(Bool(true)),
Literal(F32(2.3)),
Literal(F32(0.0)),
Override(5),
Literal(F32(2.0)),
Binary(
op: Multiply,
left: 5,
right: 4,
),
Literal(F32(2.718)),
],
functions: [],

View File

@@ -9,6 +9,7 @@ constant float specular_param = 2.3;
constant float gain = 1.1;
constant float width = 0.0;
constant float depth = 2.3;
constant float height = 4.6;
constant float inferred_f32_ = 2.718;
kernel void main_(

View File

@@ -1,25 +1,27 @@
; SPIR-V
; Version: 1.0
; Generator: rspirv
; Bound: 15
; Bound: 17
OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %12 "main"
OpExecutionMode %12 LocalSize 1 1 1
OpEntryPoint GLCompute %14 "main"
OpExecutionMode %14 LocalSize 1 1 1
%2 = OpTypeVoid
%3 = OpTypeBool
%4 = OpTypeFloat 32
%5 = OpConstantTrue %3
%6 = OpConstant %4 2.3
%7 = OpConstant %4 0.0
%8 = OpConstant %4 2.718
%9 = OpConstantFalse %3
%10 = OpConstant %4 1.1
%13 = OpTypeFunction %2
%12 = OpFunction %2 None %13
%11 = OpLabel
OpBranch %14
%14 = OpLabel
%8 = OpConstantFalse %3
%9 = OpConstant %4 1.1
%10 = OpConstant %4 2.0
%11 = OpConstant %4 4.6
%12 = OpConstant %4 2.718
%15 = OpTypeFunction %2
%14 = OpFunction %2 None %15
%13 = OpLabel
OpBranch %16
%16 = OpLabel
OpReturn
OpFunctionEnd

View File

@@ -14,8 +14,8 @@ fn main_1() {
let _e4 = a_uv_1;
v_uv = _e4;
let _e6 = a_pos_1;
let _e8 = (c_scale * _e6);
gl_Position = vec4<f32>(_e8.x, _e8.y, 0f, 1f);
let _e7 = (c_scale * _e6);
gl_Position = vec4<f32>(_e7.x, _e7.y, 0f, 1f);
return;
}