[wgsl-in] eagerly evaluate const-expressions

[wgsl-in] support const-expressions in attributes

allow `Splat` as an evaluated const-expression type
This commit is contained in:
teoxoy
2023-04-19 14:55:07 +02:00
committed by Teodor Tanasoaia
parent c5d884fca8
commit a730236b68
44 changed files with 2917 additions and 2770 deletions

View File

@@ -1709,7 +1709,7 @@ impl<'a, W: Write> Writer<'a, W> {
arg: Handle<crate::Expression>,
arg1: Handle<crate::Expression>,
size: usize,
ctx: &back::FunctionCtx<'_>,
ctx: &back::FunctionCtx,
) -> BackendResult {
// Write parantheses around the dot product expression to prevent operators
// with different precedences from applying earlier.
@@ -2254,9 +2254,12 @@ impl<'a, W: Write> Writer<'a, W> {
/// [`Expression`]: crate::Expression
/// [`Module`]: crate::Module
fn write_const_expr(&mut self, expr: Handle<crate::Expression>) -> BackendResult {
self.write_possibly_const_expr(expr, &self.module.const_expressions, |writer, expr| {
writer.write_const_expr(expr)
})
self.write_possibly_const_expr(
expr,
&self.module.const_expressions,
|expr| &self.info[expr],
|writer, expr| writer.write_const_expr(expr),
)
}
/// Write [`Expression`] variants that can occur in both runtime and const expressions.
@@ -2277,13 +2280,15 @@ impl<'a, W: Write> Writer<'a, W> {
/// Adds no newlines or leading/trailing whitespace
///
/// [`Expression`]: crate::Expression
fn write_possibly_const_expr<E>(
&mut self,
fn write_possibly_const_expr<'w, I, E>(
&'w mut self,
expr: Handle<crate::Expression>,
expressions: &crate::Arena<crate::Expression>,
info: I,
write_expression: E,
) -> BackendResult
where
I: Fn(Handle<crate::Expression>) -> &'w proc::TypeResolution,
E: Fn(&mut Self, Handle<crate::Expression>) -> BackendResult,
{
use crate::Expression;
@@ -2331,6 +2336,14 @@ impl<'a, W: Write> Writer<'a, W> {
}
write!(self.out, ")")?
}
// `Splat` needs to actually write down a vector, it's not always inferred in GLSL.
Expression::Splat { size: _, value } => {
let resolved = info(expr).inner_with(&self.module.types);
self.write_value_type(resolved)?;
write!(self.out, "(")?;
write_expression(self, value)?;
write!(self.out, ")")?
}
_ => unreachable!(),
}
@@ -2344,7 +2357,7 @@ impl<'a, W: Write> Writer<'a, W> {
fn write_expr(
&mut self,
expr: Handle<crate::Expression>,
ctx: &back::FunctionCtx<'_>,
ctx: &back::FunctionCtx,
) -> BackendResult {
use crate::Expression;
@@ -2357,10 +2370,14 @@ impl<'a, W: Write> Writer<'a, W> {
Expression::Literal(_)
| Expression::Constant(_)
| Expression::ZeroValue(_)
| Expression::Compose { .. } => {
self.write_possibly_const_expr(expr, ctx.expressions, |writer, expr| {
writer.write_expr(expr, ctx)
})?;
| Expression::Compose { .. }
| Expression::Splat { .. } => {
self.write_possibly_const_expr(
expr,
ctx.expressions,
|expr| &ctx.info[expr].ty,
|writer, expr| writer.write_expr(expr, ctx),
)?;
}
// `Access` is applied to arrays, vectors and matrices and is written as indexing
Expression::Access { base, index } => {
@@ -2407,14 +2424,6 @@ impl<'a, W: Write> Writer<'a, W> {
ref other => return Err(Error::Custom(format!("Cannot index {other:?}"))),
}
}
// `Splat` needs to actually write down a vector, it's not always inferred in GLSL.
Expression::Splat { size: _, value } => {
let resolved = ctx.info[expr].ty.inner_with(&self.module.types);
self.write_value_type(resolved)?;
write!(self.out, "(")?;
self.write_expr(value, ctx)?;
write!(self.out, ")")?
}
// `Swizzle` adds a few letters behind the dot.
Expression::Swizzle {
size,

View File

@@ -2078,6 +2078,19 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
}
write!(self.out, ")")?;
}
Expression::Splat { size, value } => {
// hlsl is not supported one value constructor
// if we write, for example, int4(0), dxc returns error:
// error: too few elements in vector initialization (expected 4 elements, have 1)
let number_of_components = match size {
crate::VectorSize::Bi => "xx",
crate::VectorSize::Tri => "xxx",
crate::VectorSize::Quad => "xxxx",
};
write!(self.out, "(")?;
write_expression(self, value)?;
write!(self.out, ").{number_of_components}")?
}
_ => unreachable!(),
}
@@ -2135,7 +2148,8 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
Expression::Literal(_)
| Expression::Constant(_)
| Expression::ZeroValue(_)
| Expression::Compose { .. } => {
| Expression::Compose { .. }
| Expression::Splat { .. } => {
self.write_possibly_const_expression(
module,
expr,
@@ -2423,7 +2437,9 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
if let Some(offset) = offset {
write!(self.out, ", ")?;
write!(self.out, "int2(")?; // work around https://github.com/microsoft/DirectXShaderCompiler/issues/5082#issuecomment-1540147807
self.write_const_expression(module, offset)?;
write!(self.out, ")")?;
}
write!(self.out, ")")?;
@@ -3154,19 +3170,6 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
self.write_expr(module, argument, func_ctx)?;
write!(self.out, ")")?
}
Expression::Splat { size, value } => {
// hlsl is not supported one value constructor
// if we write, for example, int4(0), dxc returns error:
// error: too few elements in vector initialization (expected 4 elements, have 1)
let number_of_components = match size {
crate::VectorSize::Bi => "xx",
crate::VectorSize::Tri => "xxx",
crate::VectorSize::Quad => "xxxx",
};
write!(self.out, "(")?;
self.write_expr(module, value, func_ctx)?;
write!(self.out, ").{number_of_components}")?
}
Expression::Select {
condition,
accept,

View File

@@ -501,6 +501,7 @@ struct ExpressionContext<'a> {
origin: FunctionOrigin,
info: &'a valid::FunctionInfo,
module: &'a crate::Module,
mod_info: &'a valid::ModuleInfo,
pipeline_options: &'a PipelineOptions,
policies: index::BoundsCheckPolicies,
@@ -571,7 +572,6 @@ impl<'a> ExpressionContext<'a> {
struct StatementContext<'a> {
expression: ExpressionContext<'a>,
mod_info: &'a valid::ModuleInfo,
result_struct: Option<&'a str>,
}
@@ -604,25 +604,26 @@ impl<W: Write> Writer<W> {
parameters: impl Iterator<Item = Handle<crate::Expression>>,
context: &ExpressionContext,
) -> BackendResult {
self.put_call_parameters_impl(parameters, |writer, expr| {
self.put_call_parameters_impl(parameters, context, |writer, context, expr| {
writer.put_expression(expr, context, true)
})
}
fn put_call_parameters_impl<E>(
fn put_call_parameters_impl<C, E>(
&mut self,
parameters: impl Iterator<Item = Handle<crate::Expression>>,
ctx: &C,
put_expression: E,
) -> BackendResult
where
E: Fn(&mut Self, Handle<crate::Expression>) -> BackendResult,
E: Fn(&mut Self, &C, Handle<crate::Expression>) -> BackendResult,
{
write!(self.out, "(")?;
for (i, handle) in parameters.enumerate() {
if i != 0 {
write!(self.out, ", ")?;
}
put_expression(self, handle)?;
put_expression(self, ctx, handle)?;
}
write!(self.out, ")")?;
Ok(())
@@ -1213,24 +1214,33 @@ impl<W: Write> Writer<W> {
&mut self,
expr_handle: Handle<crate::Expression>,
module: &crate::Module,
mod_info: &valid::ModuleInfo,
) -> BackendResult {
self.put_possibly_const_expression(
expr_handle,
&module.const_expressions,
module,
|writer, expr| writer.put_const_expression(expr, module),
mod_info,
&(module, mod_info),
|&(_, mod_info), expr| &mod_info[expr],
|writer, &(module, _), expr| writer.put_const_expression(expr, module, mod_info),
)
}
fn put_possibly_const_expression<E>(
#[allow(clippy::too_many_arguments)]
fn put_possibly_const_expression<C, I, E>(
&mut self,
expr_handle: Handle<crate::Expression>,
expressions: &crate::Arena<crate::Expression>,
module: &crate::Module,
mod_info: &valid::ModuleInfo,
ctx: &C,
get_expr_ty: I,
put_expression: E,
) -> BackendResult
where
E: Fn(&mut Self, Handle<crate::Expression>) -> BackendResult,
I: Fn(&C, Handle<crate::Expression>) -> &TypeResolution,
E: Fn(&mut Self, &C, Handle<crate::Expression>) -> BackendResult,
{
match expressions[expr_handle] {
crate::Expression::Literal(literal) => match literal {
@@ -1263,7 +1273,7 @@ impl<W: Write> Writer<W> {
if constant.name.is_some() {
write!(self.out, "{}", self.names[&NameKey::Constant(handle)])?;
} else {
self.put_const_expression(constant.init, module)?;
self.put_const_expression(constant.init, module, mod_info)?;
}
}
crate::Expression::ZeroValue(ty) => {
@@ -1291,7 +1301,11 @@ impl<W: Write> Writer<W> {
crate::TypeInner::Scalar { .. }
| crate::TypeInner::Vector { .. }
| crate::TypeInner::Matrix { .. } => {
self.put_call_parameters_impl(components.iter().copied(), put_expression)?;
self.put_call_parameters_impl(
components.iter().copied(),
ctx,
put_expression,
)?;
}
crate::TypeInner::Array { .. } | crate::TypeInner::Struct { .. } => {
write!(self.out, " {{")?;
@@ -1303,13 +1317,23 @@ impl<W: Write> Writer<W> {
if self.struct_member_pads.contains(&(ty, index as u32)) {
write!(self.out, "{{}}, ")?;
}
put_expression(self, component)?;
put_expression(self, ctx, component)?;
}
write!(self.out, "}}")?;
}
_ => return Err(Error::UnsupportedCompose(ty)),
}
}
crate::Expression::Splat { size, value } => {
let scalar_kind = match *get_expr_ty(ctx, value).inner_with(&module.types) {
crate::TypeInner::Scalar { kind, .. } => kind,
_ => return Err(Error::Validation),
};
put_numeric_type(&mut self.out, scalar_kind, &[size])?;
write!(self.out, "(")?;
put_expression(self, ctx, value)?;
write!(self.out, ")")?;
}
_ => unreachable!(),
}
@@ -1350,12 +1374,16 @@ impl<W: Write> Writer<W> {
crate::Expression::Literal(_)
| crate::Expression::Constant(_)
| crate::Expression::ZeroValue(_)
| crate::Expression::Compose { .. } => {
| crate::Expression::Compose { .. }
| crate::Expression::Splat { .. } => {
self.put_possibly_const_expression(
expr_handle,
&context.function.expressions,
context.module,
|writer, expr| writer.put_expression(expr, context, true),
context.mod_info,
context,
|context, expr: Handle<crate::Expression>| &context.info[expr].ty,
|writer, context, expr| writer.put_expression(expr, context, true),
)?;
}
crate::Expression::Access { base, .. }
@@ -1385,16 +1413,6 @@ impl<W: Write> Writer<W> {
self.put_access_chain(expr_handle, policy, context)?;
}
}
crate::Expression::Splat { size, value } => {
let scalar_kind = match *context.resolve_type(value) {
crate::TypeInner::Scalar { kind, .. } => kind,
_ => return Err(Error::Validation),
};
put_numeric_type(&mut self.out, scalar_kind, &[size])?;
write!(self.out, "(")?;
self.put_expression(value, context, true)?;
write!(self.out, ")")?;
}
crate::Expression::Swizzle {
size,
vector,
@@ -1469,7 +1487,7 @@ impl<W: Write> Writer<W> {
if let Some(offset) = offset {
write!(self.out, ", ")?;
self.put_const_expression(offset, context.module)?;
self.put_const_expression(offset, context.module, context.mod_info)?;
}
match gather {
@@ -2792,7 +2810,7 @@ impl<W: Write> Writer<W> {
}
// follow-up with any global resources used
let mut separate = !arguments.is_empty();
let fun_info = &context.mod_info[function];
let fun_info = &context.expression.mod_info[function];
let mut supports_array_length = false;
for (handle, var) in context.expression.module.global_variables.iter() {
if fun_info[handle].is_empty() {
@@ -3131,7 +3149,7 @@ impl<W: Write> Writer<W> {
};
self.write_type_defs(module)?;
self.write_global_constants(module)?;
self.write_global_constants(module, info)?;
self.write_functions(module, info, options, pipeline_options)
}
@@ -3338,7 +3356,11 @@ impl<W: Write> Writer<W> {
}
/// Writes all named constants
fn write_global_constants(&mut self, module: &crate::Module) -> BackendResult {
fn write_global_constants(
&mut self,
module: &crate::Module,
mod_info: &valid::ModuleInfo,
) -> BackendResult {
let constants = module.constants.iter().filter(|&(_, c)| c.name.is_some());
for (handle, constant) in constants {
@@ -3352,7 +3374,7 @@ impl<W: Write> Writer<W> {
};
let name = &self.names[&NameKey::Constant(handle)];
write!(self.out, "constant {ty_name} {name} = ")?;
self.put_const_expression(constant.init, module)?;
self.put_const_expression(constant.init, module, mod_info)?;
writeln!(self.out, ";")?;
}
@@ -3549,7 +3571,7 @@ impl<W: Write> Writer<W> {
match local.init {
Some(value) => {
write!(self.out, " = ")?;
self.put_const_expression(value, module)?;
self.put_const_expression(value, module, mod_info)?;
}
None => {
write!(self.out, " = {{}}")?;
@@ -3569,9 +3591,9 @@ impl<W: Write> Writer<W> {
policies: options.bounds_check_policies,
guarded_indices,
module,
mod_info,
pipeline_options,
},
mod_info,
result_struct: None,
};
self.named_expressions.clear();
@@ -3958,7 +3980,7 @@ impl<W: Write> Writer<W> {
}
if let Some(value) = var.init {
write!(self.out, " = ")?;
self.put_const_expression(value, module)?;
self.put_const_expression(value, module, mod_info)?;
}
writeln!(self.out)?;
}
@@ -4014,7 +4036,7 @@ impl<W: Write> Writer<W> {
match var.init {
Some(value) => {
write!(self.out, " = ")?;
self.put_const_expression(value, module)?;
self.put_const_expression(value, module, mod_info)?;
writeln!(self.out, ";")?;
}
None => {
@@ -4106,7 +4128,7 @@ impl<W: Write> Writer<W> {
match local.init {
Some(value) => {
write!(self.out, " = ")?;
self.put_const_expression(value, module)?;
self.put_const_expression(value, module, mod_info)?;
}
None => {
write!(self.out, " = {{}}")?;
@@ -4126,9 +4148,9 @@ impl<W: Write> Writer<W> {
policies: options.bounds_check_policies,
guarded_indices,
module,
mod_info,
pipeline_options,
},
mod_info,
result_struct: Some(&stage_out_name),
};
self.named_expressions.clear();

View File

@@ -198,11 +198,15 @@ impl Writer {
}
}
pub(super) fn get_expression_type_id(&mut self, tr: &TypeResolution) -> Word {
let lookup_ty = match *tr {
pub(super) fn get_expression_lookup_type(&mut self, tr: &TypeResolution) -> LookupType {
match *tr {
TypeResolution::Handle(ty_handle) => LookupType::Handle(ty_handle),
TypeResolution::Value(ref inner) => LookupType::Local(make_local(inner).unwrap()),
};
}
}
pub(super) fn get_expression_type_id(&mut self, tr: &TypeResolution) -> Word {
let lookup_ty = self.get_expression_lookup_type(tr);
self.get_type_id(lookup_ty)
}
@@ -1242,6 +1246,7 @@ impl Writer {
&mut self,
handle: Handle<crate::Expression>,
ir_module: &crate::Module,
mod_info: &ModuleInfo,
) -> Result<Word, Error> {
let id = match ir_module.const_expressions[handle] {
crate::Expression::Literal(literal) => self.get_constant_scalar(literal),
@@ -1260,6 +1265,14 @@ impl Writer {
.collect();
self.get_constant_composite(LookupType::Handle(ty), component_ids.as_slice())
}
crate::Expression::Splat { size, value } => {
let value_id = self.constant_ids[value.index()];
let component_ids = &[value_id; 4][..size as usize];
let ty = self.get_expression_lookup_type(&mod_info[handle]);
self.get_constant_composite(ty, component_ids)
}
_ => unreachable!(),
};
@@ -1878,7 +1891,7 @@ impl Writer {
self.constant_ids
.resize(ir_module.const_expressions.len(), 0);
for (handle, _) in ir_module.const_expressions.iter() {
self.write_constant_expr(handle, ir_module)?;
self.write_constant_expr(handle, ir_module, mod_info)?;
}
debug_assert!(self.constant_ids.iter().all(|&id| id != 0));

View File

@@ -1795,7 +1795,7 @@ impl MacroCall {
true => {
let offset_arg = args[num_args];
num_args += 1;
match ctx.solve_constant(offset_arg, meta) {
match ctx.eval_constant(offset_arg, meta) {
Ok(v) => Some(v),
Err(e) => {
frontend.errors.push(e);

View File

@@ -519,7 +519,7 @@ impl<'a> Context<'a> {
// Don't try to generate `AccessIndex` if in a LHS position, since it
// wouldn't produce a pointer.
ExprPos::Lhs => None,
_ => self.solve_constant(index, index_meta).ok(),
_ => self.eval_constant(index, index_meta).ok(),
};
let base = self

View File

@@ -1,5 +1,5 @@
use super::{constants::ConstantSolvingError, token::TokenValue};
use crate::Span;
use super::token::TokenValue;
use crate::{proc::ConstantEvaluatorError, Span};
use pp_rs::token::PreprocessorError;
use std::borrow::Cow;
use thiserror::Error;
@@ -116,8 +116,8 @@ pub enum ErrorKind {
InternalError(&'static str),
}
impl From<ConstantSolvingError> for ErrorKind {
fn from(err: ConstantSolvingError) -> Self {
impl From<ConstantEvaluatorError> for ErrorKind {
fn from(err: ConstantEvaluatorError) -> Self {
ErrorKind::SemanticError(err.to_string().into())
}
}

View File

@@ -22,7 +22,6 @@ use parser::ParsingContext;
mod ast;
mod builtins;
mod constants;
mod context;
mod error;
mod functions;

View File

@@ -226,7 +226,7 @@ impl<'source> ParsingContext<'source> {
let expr = self.parse_conditional(frontend, &mut ctx, &mut stmt_ctx, None)?;
let (root, meta) = ctx.lower_expect(stmt_ctx, frontend, expr, ExprPos::Rhs)?;
Ok((ctx.solve_constant(root, meta)?, meta))
Ok((ctx.eval_constant(root, meta)?, meta))
}
}

View File

@@ -241,7 +241,7 @@ impl<'source> ParsingContext<'source> {
let is_const = ctx.qualifiers.storage.0 == StorageQualifier::Const;
let maybe_const_expr = if ctx.external {
if let Some((root, meta)) = init {
match ctx.ctx.solve_constant(root, meta) {
match ctx.ctx.eval_constant(root, meta) {
Ok(res) => Some(res),
// If the declaration is external (global scope) and is constant qualified
// then the initializer must be a constant expression

View File

@@ -187,11 +187,8 @@ impl<'source> ParsingContext<'source> {
TokenValue::Case => {
self.bump(frontend)?;
let mut stmt = ctx.stmt_ctx();
let expr = self.parse_expression(frontend, ctx, &mut stmt)?;
let (root, meta) =
ctx.lower_expect(stmt, frontend, expr, ExprPos::Rhs)?;
let const_expr = ctx.solve_constant(root, meta)?;
let (const_expr, meta) =
self.parse_constant_expression(frontend, ctx.module)?;
match ctx.module.const_expressions[const_expr] {
Expression::Literal(Literal::I32(value)) => match uint {

View File

@@ -1,4 +1,4 @@
use super::{constants::ConstantSolver, context::Context, Error, ErrorKind, Result, Span};
use super::{context::Context, Error, ErrorKind, Result, Span};
use crate::{
proc::ResolveContext, Bytes, Expression, Handle, ImageClass, ImageDimension, ScalarKind, Type,
TypeInner, VectorSize,
@@ -305,19 +305,28 @@ impl Context<'_> {
})
}
pub(crate) fn solve_constant(
pub(crate) fn eval_constant(
&mut self,
root: Handle<Expression>,
meta: Span,
) -> Result<Handle<Expression>> {
let mut solver = ConstantSolver {
let mut solver = crate::proc::ConstantEvaluator {
types: &mut self.module.types,
expressions: &self.expressions,
expressions: &mut self.module.const_expressions,
constants: &mut self.module.constants,
const_expressions: &mut self.module.const_expressions,
const_expressions: Some(&self.expressions),
append: None::<
Box<
dyn FnMut(
&mut crate::Arena<Expression>,
Expression,
Span,
) -> Handle<Expression>,
>,
>,
};
solver.solve(root).map_err(|e| Error {
solver.eval(root).map_err(|e| Error {
kind: e.into(),
meta,
})

View File

@@ -34,6 +34,9 @@ impl Emitter {
}
self.start_len = Some(arena.len());
}
const fn is_running(&self) -> bool {
self.start_len.is_some()
}
#[must_use]
fn finish(
&mut self,

View File

@@ -1,5 +1,5 @@
use crate::front::wgsl::parse::lexer::Token;
use crate::proc::{Alignment, ResolveError};
use crate::proc::{Alignment, ConstantEvaluatorError, ResolveError};
use crate::{SourceLocation, Span};
use codespan_reporting::diagnostic::{Diagnostic, Label};
use codespan_reporting::files::SimpleFile;
@@ -98,8 +98,6 @@ impl std::error::Error for ParseError {
pub enum ExpectedToken<'a> {
Token(Token<'a>),
Identifier,
Number,
Integer,
/// Expected: constant, parenthesized expression, identifier
PrimaryExpression,
/// Expected: assignment, increment/decrement expression
@@ -141,10 +139,6 @@ pub enum Error<'a> {
UnexpectedComponents(Span),
UnexpectedOperationInConstContext(Span),
BadNumber(Span, NumberError),
/// A negative signed integer literal where both signed and unsigned,
/// but only non-negative literals are allowed.
NegativeInt(Span),
BadU32Constant(Span),
BadMatrixScalarKind(Span, crate::ScalarKind, u8),
BadAccessor(Span),
BadTexture(Span),
@@ -240,9 +234,11 @@ pub enum Error<'a> {
FunctionReturnsVoid(Span),
InvalidWorkGroupUniformLoad(Span),
Other,
ExpectedArraySize(Span),
NonPositiveArrayLength(Span),
ExpectedConstExprConcreteIntegerScalar(Span),
ExpectedNonNegative(Span),
ExpectedPositiveArrayLength(Span),
MissingWorkgroupSize(Span),
ConstantEvaluatorError(ConstantEvaluatorError, Span),
}
impl<'a> Error<'a> {
@@ -271,8 +267,6 @@ impl<'a> Error<'a> {
}
}
ExpectedToken::Identifier => "identifier".to_string(),
ExpectedToken::Number => "32-bit signed integer literal".to_string(),
ExpectedToken::Integer => "unsigned/signed integer literal".to_string(),
ExpectedToken::PrimaryExpression => "expression".to_string(),
ExpectedToken::Assignment => "assignment or increment/decrement".to_string(),
ExpectedToken::SwitchItem => "switch item ('case' or 'default') or a closing curly bracket to signify the end of the switch statement ('}')".to_string(),
@@ -306,22 +300,6 @@ impl<'a> Error<'a> {
labels: vec![(bad_span, err.to_string().into())],
notes: vec![],
},
Error::NegativeInt(bad_span) => ParseError {
message: format!(
"expected non-negative integer literal, found `{}`",
&source[bad_span],
),
labels: vec![(bad_span, "expected non-negative integer".into())],
notes: vec![],
},
Error::BadU32Constant(bad_span) => ParseError {
message: format!(
"expected unsigned integer constant expression, found `{}`",
&source[bad_span],
),
labels: vec![(bad_span, "expected unsigned integer".into())],
notes: vec![],
},
Error::BadMatrixScalarKind(span, kind, width) => ParseError {
message: format!(
"matrix scalar type must be floating-point, but found `{}`",
@@ -694,15 +672,24 @@ impl<'a> Error<'a> {
labels: vec![],
notes: vec![],
},
Error::ExpectedArraySize(span) => ParseError {
message: "array element count must resolve to an integer scalar (u32 or i32)"
.to_string(),
labels: vec![(span, "must resolve to u32/i32".into())],
Error::ExpectedConstExprConcreteIntegerScalar(span) => ParseError {
message: "must be a const-expression that resolves to a concrete integer scalar (u32 or i32)".to_string(),
labels: vec![(span, "must resolve to u32 or i32".into())],
notes: vec![],
},
Error::NonPositiveArrayLength(span) => ParseError {
message: "array element count must be greater than zero".to_string(),
labels: vec![(span, "must be greater than zero".into())],
Error::ExpectedNonNegative(span) => ParseError {
message: "must be non-negative (>= 0)".to_string(),
labels: vec![(span, "must be non-negative".into())],
notes: vec![],
},
Error::ExpectedPositiveArrayLength(span) => ParseError {
message: "array element count must be positive (> 0)".to_string(),
labels: vec![(span, "must be positive".into())],
notes: vec![],
},
Error::ConstantEvaluatorError(ref e, span) => ParseError {
message: e.to_string(),
labels: vec![(span, "see msg".into())],
notes: vec![],
},
Error::MissingWorkgroupSize(span) => ParseError {

View File

@@ -197,7 +197,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
// Empty constructor
(Components::None, dst_ty) => match dst_ty {
ConcreteConstructor::Type(ty, _) => {
return Ok(ctx.interrupt_emitter(crate::Expression::ZeroValue(ty), span))
return ctx.append_expression(crate::Expression::ZeroValue(ty), span)
}
_ => return Err(Error::TypeNotInferrable(ty_span)),
},
@@ -408,7 +408,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
Default::default(),
)
})
.collect();
.collect::<Result<Vec<_>, _>>()?;
let ty = ctx.ensure_type_exists(crate::TypeInner::Matrix {
columns,
@@ -523,7 +523,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
_ => return Err(Error::TypeNotConstructible(ty_span)),
};
let expr = ctx.append_expression(expr, span);
let expr = ctx.append_expression(expr, span)?;
Ok(expr)
}
@@ -585,7 +585,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
let size = match size {
ast::ArraySize::Constant(expr) => {
let const_expr = self.expression(expr, ctx.as_const())?;
crate::ArraySize::Constant(ctx.array_length(const_expr)?)
crate::ArraySize::Constant(ctx.as_const().array_length(const_expr)?)
}
ast::ArraySize::Dynamic => crate::ArraySize::Dynamic,
};

View File

@@ -5,7 +5,9 @@ use crate::front::wgsl::index::Index;
use crate::front::wgsl::parse::number::Number;
use crate::front::wgsl::parse::{ast, conv};
use crate::front::{Emitter, Typifier};
use crate::proc::{ensure_block_returns, Alignment, Layouter, ResolveContext, TypeResolution};
use crate::proc::{
ensure_block_returns, Alignment, ConstantEvaluator, Layouter, ResolveContext, TypeResolution,
};
use crate::{Arena, FastHashMap, FastIndexMap, Handle, Span};
mod construction;
@@ -327,17 +329,66 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
&mut self,
expr: crate::Expression,
span: Span,
) -> Handle<crate::Expression> {
) -> Result<Handle<crate::Expression>, Error<'source>> {
match self.expr_type {
ExpressionContextType::Runtime(ref mut ctx) => ctx.naga_expressions.append(expr, span),
ExpressionContextType::Constant => self.module.const_expressions.append(expr, span),
ExpressionContextType::Runtime(ref mut rctx) => {
let mut eval = ConstantEvaluator {
types: &mut self.module.types,
constants: &self.module.constants,
expressions: rctx.naga_expressions,
const_expressions: Some(&self.module.const_expressions),
append: Some(
|arena: &mut Arena<crate::Expression>, expr: crate::Expression, span| {
let is_running = rctx.emitter.is_running();
let needs_pre_emit = expr.needs_pre_emit();
if is_running && needs_pre_emit {
rctx.block.extend(rctx.emitter.finish(arena));
}
let h = arena.append(expr, span);
if is_running && needs_pre_emit {
rctx.emitter.start(arena);
}
h
},
),
};
match eval.try_eval_and_append(&expr, span) {
Ok(expr) => Ok(expr),
Err(_) => Ok(rctx.naga_expressions.append(expr, span)),
}
}
ExpressionContextType::Constant => {
let mut eval = ConstantEvaluator {
types: &mut self.module.types,
constants: &self.module.constants,
expressions: &mut self.module.const_expressions,
const_expressions: None,
append: None::<
Box<
dyn FnMut(
&mut Arena<crate::Expression>,
crate::Expression,
Span,
) -> Handle<crate::Expression>,
>,
>,
};
eval.try_eval_and_append(&expr, span)
.map_err(|e| Error::ConstantEvaluatorError(e, span))
}
}
}
fn get_expression(&self, handle: Handle<crate::Expression>) -> &crate::Expression {
fn const_access(&self, handle: Handle<crate::Expression>) -> Option<u32> {
match self.expr_type {
ExpressionContextType::Runtime(ref ctx) => &ctx.naga_expressions[handle],
ExpressionContextType::Constant => &self.module.const_expressions[handle],
ExpressionContextType::Runtime(ref ctx) => self
.module
.to_ctx()
.eval_expr_to_u32_from(handle, ctx.naga_expressions)
.ok(),
ExpressionContextType::Constant => self.module.to_ctx().eval_expr_to_u32(handle).ok(),
}
}
@@ -366,34 +417,52 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
}
fn array_length(
&self,
&mut self,
const_expr: Handle<crate::Expression>,
) -> Result<NonZeroU32, Error<'source>> {
let span = self.module.const_expressions.get_span(const_expr);
let len = self
.module
.to_ctx()
.eval_expr_to_u32(const_expr)
.map_err(|err| match err {
crate::proc::U32EvalError::NonConst => Error::ExpectedArraySize(span),
crate::proc::U32EvalError::Negative => Error::NonPositiveArrayLength(span),
})?;
NonZeroU32::new(len).ok_or(Error::NonPositiveArrayLength(span))
match self.expr_type {
ExpressionContextType::Runtime(_) => {
unreachable!()
}
ExpressionContextType::Constant => {
let span = self.module.const_expressions.get_span(const_expr);
let len =
self.module
.to_ctx()
.eval_expr_to_u32(const_expr)
.map_err(|err| match err {
crate::proc::U32EvalError::NonConst => {
Error::ExpectedConstExprConcreteIntegerScalar(span)
}
crate::proc::U32EvalError::Negative => {
Error::ExpectedPositiveArrayLength(span)
}
})?;
NonZeroU32::new(len).ok_or(Error::ExpectedPositiveArrayLength(span))
}
}
}
fn gather_component(
&self,
&mut self,
expr: Handle<crate::Expression>,
gather_span: Span,
) -> Result<crate::SwizzleComponent, Error<'source>> {
match self.expr_type {
ExpressionContextType::Runtime(ref ctx) => {
let expr_span = ctx.naga_expressions.get_span(expr);
ExpressionContextType::Runtime(ref rctx) => {
let expr_span = rctx.naga_expressions.get_span(expr);
let index = self
.module
.to_ctx()
.eval_expr_to_u32_from(expr, ctx.naga_expressions)
.map_err(|_| Error::InvalidGatherComponent(expr_span))?;
.eval_expr_to_u32_from(expr, rctx.naga_expressions)
.map_err(|err| match err {
crate::proc::U32EvalError::NonConst => {
Error::ExpectedConstExprConcreteIntegerScalar(expr_span)
}
crate::proc::U32EvalError::Negative => {
Error::ExpectedNonNegative(expr_span)
}
})?;
crate::SwizzleComponent::XYZW
.get(index as usize)
.copied()
@@ -543,13 +612,13 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
value: *right,
},
self.get_expression_span(*right),
);
)?;
}
(&crate::TypeInner::Scalar { .. }, &crate::TypeInner::Vector { size, .. }) => {
*left = self.append_expression(
crate::Expression::Splat { size, value: *left },
self.get_expression_span(*left),
);
)?;
}
_ => {}
}
@@ -566,7 +635,7 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
&mut self,
expression: crate::Expression,
span: Span,
) -> Handle<crate::Expression> {
) -> Result<Handle<crate::Expression>, Error<'source>> {
match self.expr_type {
ExpressionContextType::Runtime(ref mut rctx) => {
rctx.block
@@ -588,7 +657,10 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
///
/// If `expr` is has type `ref<SC, T, A>`, perform a load to produce a value of type
/// `T`. Otherwise, return `expr` unchanged.
fn apply_load_rule(&mut self, expr: TypedExpression) -> Handle<crate::Expression> {
fn apply_load_rule(
&mut self,
expr: TypedExpression,
) -> Result<Handle<crate::Expression>, Error<'source>> {
if expr.is_reference {
let load = crate::Expression::Load {
pointer: expr.handle,
@@ -596,7 +668,7 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
let span = self.get_expression_span(expr.handle);
self.append_expression(load, span)
} else {
expr.handle
Ok(expr.handle)
}
}
@@ -831,11 +903,20 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
.map(|init| self.expression(init, ctx.as_const()))
.transpose()?;
let binding = if let Some(ref binding) = v.binding {
Some(crate::ResourceBinding {
group: self.const_u32(binding.group, ctx.as_const())?.0,
binding: self.const_u32(binding.binding, ctx.as_const())?.0,
})
} else {
None
};
let handle = ctx.module.global_variables.append(
crate::GlobalVariable {
name: Some(v.name.name.to_string()),
space: v.space,
binding: v.binding.clone(),
binding,
ty,
init,
},
@@ -930,7 +1011,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
Ok(crate::FunctionArgument {
name: Some(arg.name.name.to_string()),
ty,
binding: self.interpolate_default(&arg.binding, ty, ctx.reborrow()),
binding: self.binding(&arg.binding, ty, ctx.reborrow())?,
})
})
.collect::<Result<Vec<_>, _>>()?;
@@ -939,11 +1020,11 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
.result
.as_ref()
.map(|res| {
self.resolve_ast_type(res.ty, ctx.reborrow())
.map(|ty| crate::FunctionResult {
ty,
binding: self.interpolate_default(&res.binding, ty, ctx.reborrow()),
})
let ty = self.resolve_ast_type(res.ty, ctx.reborrow())?;
Ok(crate::FunctionResult {
ty,
binding: self.binding(&res.binding, ty, ctx.reborrow())?,
})
})
.transpose()?;
@@ -980,11 +1061,24 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
};
if let Some(ref entry) = f.entry_point {
let workgroup_size = if let Some(workgroup_size) = entry.workgroup_size {
// TODO: replace with try_map once stabilized
let mut workgroup_size_out = [1; 3];
for (i, size) in workgroup_size.into_iter().enumerate() {
if let Some(size_expr) = size {
workgroup_size_out[i] = self.const_u32(size_expr, ctx.as_const())?.0;
}
}
workgroup_size_out
} else {
[0; 3]
};
ctx.module.entry_points.push(crate::EntryPoint {
name: f.name.name.to_string(),
stage: entry.stage,
early_depth_test: entry.early_depth_test,
workgroup_size: entry.workgroup_size.unwrap_or([0, 0, 0]),
workgroup_size,
function,
});
Ok(LoweredGlobalDecl::EntryPoint)
@@ -1106,9 +1200,10 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
stmt.span,
);
let handle = ctx
.as_expression(block, &mut emitter)
.interrupt_emitter(crate::Expression::LocalVariable(var), Span::UNDEFINED);
let handle = ctx.as_expression(block, &mut emitter).interrupt_emitter(
crate::Expression::LocalVariable(var),
Span::UNDEFINED,
)?;
block.extend(emitter.finish(ctx.naga_expressions));
ctx.local_table.insert(
v.handle,
@@ -1168,19 +1263,24 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
.map(|case| {
Ok(crate::SwitchCase {
value: match case.value {
ast::SwitchValue::I32(value) if !uint => {
crate::SwitchValue::I32(value)
}
ast::SwitchValue::U32(value) if uint => {
crate::SwitchValue::U32(value)
ast::SwitchValue::Expr(expr) => {
let expr = self.expression(expr, ctx.as_global().as_const())?;
match ctx.module.to_ctx().eval_expr_to_literal(expr) {
Some(crate::Literal::I32(value)) if !uint => {
crate::SwitchValue::I32(value)
}
Some(crate::Literal::U32(value)) if uint => {
crate::SwitchValue::U32(value)
}
_ => {
return Err(Error::InvalidSwitchValue {
uint,
span: ctx.module.const_expressions.get_span(expr),
});
}
}
}
ast::SwitchValue::Default => crate::SwitchValue::Default,
_ => {
return Err(Error::InvalidSwitchValue {
uint,
span: case.value_span,
});
}
},
body: self.block(&case.body, ctx.reborrow())?,
fall_through: case.fall_through,
@@ -1261,7 +1361,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
let value = match op {
Some(op) => {
let mut ctx = ctx.as_expression(block, &mut emitter);
let mut left = ctx.apply_load_rule(expr);
let mut left = ctx.apply_load_rule(expr)?;
ctx.binary_op_splat(op, &mut left, &mut value)?;
ctx.append_expression(
crate::Expression::Binary {
@@ -1270,7 +1370,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
right: value,
},
stmt.span,
)
)?
}
None => value,
};
@@ -1319,7 +1419,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
};
let right =
ectx.interrupt_emitter(crate::Expression::Literal(literal), Span::UNDEFINED);
ectx.interrupt_emitter(crate::Expression::Literal(literal), Span::UNDEFINED)?;
let rctx = ectx.runtime_expression_ctx(stmt.span)?;
let left = rctx.naga_expressions.append(
crate::Expression::Load {
@@ -1358,7 +1458,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
mut ctx: ExpressionContext<'source, '_, '_>,
) -> Result<Handle<crate::Expression>, Error<'source>> {
let expr = self.expression_for_reference(expr, ctx.reborrow())?;
Ok(ctx.apply_load_rule(expr))
ctx.apply_load_rule(expr)
}
fn expression_for_reference(
@@ -1380,7 +1480,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
}
ast::Literal::Bool(b) => crate::Literal::Bool(b),
};
let handle = ctx.interrupt_emitter(crate::Expression::Literal(literal), span);
let handle = ctx.interrupt_emitter(crate::Expression::Literal(literal), span)?;
return Ok(TypedExpression::non_reference(handle));
}
ast::Expression::Ident(ast::IdentExpr::Local(local)) => {
@@ -1403,7 +1503,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
}
};
let handle = ctx.interrupt_emitter(expr, span);
let handle = ctx.interrupt_emitter(expr, span)?;
Ok(TypedExpression {
handle,
is_reference,
@@ -1483,16 +1583,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
));
}
if let crate::Expression::Literal(lit) = *ctx.get_expression(index) {
let span = ctx.get_expression_span(index);
let index = match lit {
crate::Literal::U32(index) => Ok(index),
crate::Literal::I32(index) => {
u32::try_from(index).map_err(|_| Error::BadU32Constant(span))
}
_ => Err(Error::BadU32Constant(span)),
}?;
if let Some(index) = ctx.const_access(index) {
(
crate::Expression::AccessIndex {
base: expr.handle,
@@ -1572,7 +1663,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
let vector = ctx.apply_load_rule(TypedExpression {
handle,
is_reference,
});
})?;
(
crate::Expression::Swizzle {
@@ -1625,7 +1716,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
}
};
let handle = ctx.append_expression(expr, span);
let handle = ctx.append_expression(expr, span)?;
Ok(TypedExpression {
handle,
is_reference,
@@ -1921,7 +2012,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
_ => return Err(Error::InvalidAtomicOperandType(value_span)),
};
let result = ctx.interrupt_emitter(expression, span);
let result = ctx.interrupt_emitter(expression, span)?;
let rctx = ctx.runtime_expression_ctx(span)?;
rctx.block.push(
crate::Statement::Atomic {
@@ -1973,7 +2064,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
let result = ctx.interrupt_emitter(
crate::Expression::WorkGroupUniformLoadResult { ty: result_ty },
span,
);
)?;
let rctx = ctx.runtime_expression_ctx(span)?;
rctx.block.push(
crate::Statement::WorkGroupUniformLoad { pointer, result },
@@ -2126,8 +2217,10 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
let query = self.ray_query_pointer(args.next()?, ctx.reborrow())?;
args.finish()?;
let result = ctx
.interrupt_emitter(crate::Expression::RayQueryProceedResult, span);
let result = ctx.interrupt_emitter(
crate::Expression::RayQueryProceedResult,
span,
)?;
let fun = crate::RayQueryFunction::Proceed { result };
let rctx = ctx.runtime_expression_ctx(span)?;
rctx.block
@@ -2161,7 +2254,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
}
};
let expr = ctx.append_expression(expr, span);
let expr = ctx.append_expression(expr, span)?;
Ok(Some(expr))
}
}
@@ -2214,7 +2307,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
comparison: false,
},
span,
);
)?;
let rctx = ctx.runtime_expression_ctx(span)?;
rctx.block.push(
crate::Statement::Atomic {
@@ -2367,7 +2460,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
let member_min_size = self.layouter[ty].size;
let member_min_alignment = self.layouter[ty].alignment;
let member_size = if let Some((size, span)) = member.size {
let member_size = if let Some(size_expr) = member.size {
let (size, span) = self.const_u32(size_expr, ctx.as_const())?;
if size < member_min_size {
return Err(Error::SizeAttributeTooLow(span, member_min_size));
} else {
@@ -2377,7 +2471,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
member_min_size
};
let member_alignment = if let Some((align, span)) = member.align {
let member_alignment = if let Some(align_expr) = member.align {
let (align, span) = self.const_u32(align_expr, ctx.as_const())?;
if let Some(alignment) = Alignment::new(align) {
if alignment < member_min_alignment {
return Err(Error::AlignAttributeTooLow(span, member_min_alignment));
@@ -2391,7 +2486,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
member_min_alignment
};
let binding = self.interpolate_default(&member.binding, ty, ctx.reborrow());
let binding = self.binding(&member.binding, ty, ctx.reborrow())?;
offset = member_alignment.round_up(offset);
struct_alignment = struct_alignment.max(member_alignment);
@@ -2422,6 +2517,26 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
Ok(handle)
}
fn const_u32(
&mut self,
expr: Handle<ast::Expression<'source>>,
mut ctx: ExpressionContext<'source, '_, '_>,
) -> Result<(u32, Span), Error<'source>> {
let expr = self.expression(expr, ctx.reborrow())?;
let span = ctx.module.const_expressions.get_span(expr);
let value = ctx
.module
.to_ctx()
.eval_expr_to_u32(expr)
.map_err(|err| match err {
crate::proc::U32EvalError::NonConst => {
Error::ExpectedConstExprConcreteIntegerScalar(span)
}
crate::proc::U32EvalError::Negative => Error::ExpectedNonNegative(span),
})?;
Ok((value, span))
}
/// Return a Naga `Handle<Type>` representing the front-end type `handle`.
fn resolve_ast_type(
&mut self,
@@ -2507,18 +2622,31 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
Ok(ctx.ensure_type_exists(inner))
}
fn interpolate_default(
fn binding(
&mut self,
binding: &Option<crate::Binding>,
binding: &Option<ast::Binding<'source>>,
ty: Handle<crate::Type>,
ctx: GlobalContext<'source, '_, '_>,
) -> Option<crate::Binding> {
let mut binding = binding.clone();
if let Some(ref mut binding) = binding {
binding.apply_default_interpolation(&ctx.module.types[ty].inner);
}
binding
mut ctx: GlobalContext<'source, '_, '_>,
) -> Result<Option<crate::Binding>, Error<'source>> {
Ok(match *binding {
Some(ast::Binding::BuiltIn(b)) => Some(crate::Binding::BuiltIn(b)),
Some(ast::Binding::Location {
location,
second_blend_source,
interpolation,
sampling,
}) => {
let mut binding = crate::Binding::Location {
location: self.const_u32(location, ctx.as_const())?.0,
second_blend_source,
interpolation,
sampling,
};
binding.apply_default_interpolation(&ctx.module.types[ty].inner);
Some(binding)
}
None => None,
})
}
fn ray_query_pointer(

View File

@@ -89,21 +89,21 @@ pub enum GlobalDeclKind<'a> {
pub struct FunctionArgument<'a> {
pub name: Ident<'a>,
pub ty: Handle<Type<'a>>,
pub binding: Option<crate::Binding>,
pub binding: Option<Binding<'a>>,
pub handle: Handle<Local>,
}
#[derive(Debug)]
pub struct FunctionResult<'a> {
pub ty: Handle<Type<'a>>,
pub binding: Option<crate::Binding>,
pub binding: Option<Binding<'a>>,
}
#[derive(Debug)]
pub struct EntryPoint {
pub struct EntryPoint<'a> {
pub stage: crate::ShaderStage,
pub early_depth_test: Option<crate::EarlyDepthTest>,
pub workgroup_size: Option<[u32; 3]>,
pub workgroup_size: Option<[Option<Handle<Expression<'a>>>; 3]>,
}
#[cfg(doc)]
@@ -111,7 +111,7 @@ use crate::front::wgsl::lower::{RuntimeExpressionContext, StatementContext};
#[derive(Debug)]
pub struct Function<'a> {
pub entry_point: Option<EntryPoint>,
pub entry_point: Option<EntryPoint<'a>>,
pub name: Ident<'a>,
pub arguments: Vec<FunctionArgument<'a>>,
pub result: Option<FunctionResult<'a>>,
@@ -145,11 +145,28 @@ pub struct Function<'a> {
pub body: Block<'a>,
}
#[derive(Debug)]
pub enum Binding<'a> {
BuiltIn(crate::BuiltIn),
Location {
location: Handle<Expression<'a>>,
second_blend_source: bool,
interpolation: Option<crate::Interpolation>,
sampling: Option<crate::Sampling>,
},
}
#[derive(Debug)]
pub struct ResourceBinding<'a> {
pub group: Handle<Expression<'a>>,
pub binding: Handle<Expression<'a>>,
}
#[derive(Debug)]
pub struct GlobalVariable<'a> {
pub name: Ident<'a>,
pub space: crate::AddressSpace,
pub binding: Option<crate::ResourceBinding>,
pub binding: Option<ResourceBinding<'a>>,
pub ty: Handle<Type<'a>>,
pub init: Option<Handle<Expression<'a>>>,
}
@@ -158,9 +175,9 @@ pub struct GlobalVariable<'a> {
pub struct StructMember<'a> {
pub name: Ident<'a>,
pub ty: Handle<Type<'a>>,
pub binding: Option<crate::Binding>,
pub align: Option<(u32, Span)>,
pub size: Option<(u32, Span)>,
pub binding: Option<Binding<'a>>,
pub align: Option<Handle<Expression<'a>>>,
pub size: Option<Handle<Expression<'a>>>,
}
#[derive(Debug)]
@@ -292,16 +309,14 @@ pub enum StatementKind<'a> {
}
#[derive(Debug)]
pub enum SwitchValue {
I32(i32),
U32(u32),
pub enum SwitchValue<'a> {
Expr(Handle<Expression<'a>>),
Default,
}
#[derive(Debug)]
pub struct SwitchCase<'a> {
pub value: SwitchValue,
pub value_span: Span,
pub value: SwitchValue<'a>,
pub body: Block<'a>,
pub fall_through: bool,
}

View File

@@ -141,8 +141,8 @@ impl<T> ParsedAttribute<T> {
}
#[derive(Default)]
struct BindingParser {
location: ParsedAttribute<u32>,
struct BindingParser<'a> {
location: ParsedAttribute<Handle<ast::Expression<'a>>>,
second_blend_source: ParsedAttribute<bool>,
built_in: ParsedAttribute<crate::BuiltIn>,
interpolation: ParsedAttribute<crate::Interpolation>,
@@ -150,18 +150,20 @@ struct BindingParser {
invariant: ParsedAttribute<bool>,
}
impl BindingParser {
fn parse<'a>(
impl<'a> BindingParser<'a> {
fn parse(
&mut self,
parser: &mut Parser,
lexer: &mut Lexer<'a>,
name: &'a str,
name_span: Span,
mut ctx: ExpressionContext<'a, '_, '_>,
) -> Result<(), Error<'a>> {
match name {
"location" => {
lexer.expect(Token::Paren('('))?;
self.location
.set(Parser::non_negative_i32_literal(lexer)?, name_span)?;
.set(parser.general_expression(lexer, ctx.reborrow())?, name_span)?;
lexer.expect(Token::Paren(')'))?;
}
"builtin" => {
@@ -194,7 +196,7 @@ impl BindingParser {
Ok(())
}
fn finish<'a>(self, span: Span) -> Result<Option<crate::Binding>, Error<'a>> {
fn finish(self, span: Span) -> Result<Option<ast::Binding<'a>>, Error<'a>> {
match (
self.location.value,
self.built_in.value,
@@ -208,7 +210,7 @@ impl BindingParser {
// `apply_default_interpolation` to ensure that the interpolation and
// sampling have been explicitly specified on all vertex shader output and fragment
// shader input user bindings, so leaving them potentially `None` here is fine.
Ok(Some(crate::Binding::Location {
Ok(Some(ast::Binding::Location {
location,
interpolation,
sampling,
@@ -216,13 +218,11 @@ impl BindingParser {
}))
}
(None, Some(crate::BuiltIn::Position { .. }), None, None, invariant) => {
Ok(Some(crate::Binding::BuiltIn(crate::BuiltIn::Position {
Ok(Some(ast::Binding::BuiltIn(crate::BuiltIn::Position {
invariant,
})))
}
(None, Some(built_in), None, None, false) => {
Ok(Some(crate::Binding::BuiltIn(built_in)))
}
(None, Some(built_in), None, None, false) => Ok(Some(ast::Binding::BuiltIn(built_in))),
(_, _, _, _, _) => Err(Error::InconsistentBinding(span)),
}
}
@@ -255,41 +255,18 @@ impl Parser {
lexer.span_from(initial)
}
fn switch_value<'a>(lexer: &mut Lexer<'a>) -> Result<(ast::SwitchValue, Span), Error<'a>> {
let token_span = lexer.next();
match token_span.0 {
Token::Word("default") => Ok((ast::SwitchValue::Default, token_span.1)),
Token::Number(Ok(Number::U32(num))) => Ok((ast::SwitchValue::U32(num), token_span.1)),
Token::Number(Ok(Number::I32(num))) => Ok((ast::SwitchValue::I32(num), token_span.1)),
Token::Number(Err(e)) => Err(Error::BadNumber(token_span.1, e)),
_ => Err(Error::Unexpected(token_span.1, ExpectedToken::Integer)),
fn switch_value<'a>(
&mut self,
lexer: &mut Lexer<'a>,
mut ctx: ExpressionContext<'a, '_, '_>,
) -> Result<ast::SwitchValue<'a>, Error<'a>> {
if let Token::Word("default") = lexer.peek().0 {
let _ = lexer.next();
return Ok(ast::SwitchValue::Default);
}
}
/// Parse a non-negative signed integer literal.
/// This is for attributes like `size`, `location` and others.
fn non_negative_i32_literal<'a>(lexer: &mut Lexer<'a>) -> Result<u32, Error<'a>> {
match lexer.next() {
(Token::Number(Ok(Number::I32(num))), span) => {
u32::try_from(num).map_err(|_| Error::NegativeInt(span))
}
(Token::Number(Err(e)), span) => Err(Error::BadNumber(span, e)),
other => Err(Error::Unexpected(other.1, ExpectedToken::Number)),
}
}
/// Parse a non-negative integer literal that may be either signed or unsigned.
/// This is for the `workgroup_size` attribute and array lengths.
/// Note: these values should be no larger than [`i32::MAX`], but this is not checked here.
fn generic_non_negative_int_literal<'a>(lexer: &mut Lexer<'a>) -> Result<u32, Error<'a>> {
match lexer.next() {
(Token::Number(Ok(Number::I32(num))), span) => {
u32::try_from(num).map_err(|_| Error::NegativeInt(span))
}
(Token::Number(Ok(Number::U32(num))), _) => Ok(num),
(Token::Number(Err(e)), span) => Err(Error::BadNumber(span, e)),
other => Err(Error::Unexpected(other.1, ExpectedToken::Number)),
}
let expr = self.general_expression(lexer, ctx.reborrow())?;
Ok(ast::SwitchValue::Expr(expr))
}
/// Decide if we're looking at a construction expression, and return its
@@ -1028,17 +1005,19 @@ impl Parser {
match lexer.next_ident_with_span()? {
("size", name_span) => {
lexer.expect(Token::Paren('('))?;
let (value, span) = lexer.capture_span(Self::non_negative_i32_literal)?;
let expr = self.general_expression(lexer, ctx.reborrow())?;
lexer.expect(Token::Paren(')'))?;
size.set((value, span), name_span)?;
size.set(expr, name_span)?;
}
("align", name_span) => {
lexer.expect(Token::Paren('('))?;
let (value, span) = lexer.capture_span(Self::non_negative_i32_literal)?;
let expr = self.general_expression(lexer, ctx.reborrow())?;
lexer.expect(Token::Paren(')'))?;
align.set((value, span), name_span)?;
align.set(expr, name_span)?;
}
(word, word_span) => {
bind_parser.parse(self, lexer, word, word_span, ctx.reborrow())?
}
(word, word_span) => bind_parser.parse(lexer, word, word_span)?,
}
}
@@ -1765,19 +1744,18 @@ impl Parser {
match lexer.next() {
(Token::Word("case"), _) => {
// parse a list of values
let (value, value_span) = loop {
let (value, value_span) = Self::switch_value(lexer)?;
let value = loop {
let value = self.switch_value(lexer, ctx.reborrow())?;
if lexer.skip(Token::Separator(',')) {
if lexer.skip(Token::Separator(':')) {
break (value, value_span);
break value;
}
} else {
lexer.skip(Token::Separator(':'));
break (value, value_span);
break value;
}
cases.push(ast::SwitchCase {
value,
value_span,
body: ast::Block::default(),
fall_through: true,
});
@@ -1787,17 +1765,15 @@ impl Parser {
cases.push(ast::SwitchCase {
value,
value_span,
body,
fall_through: false,
});
}
(Token::Word("default"), value_span) => {
(Token::Word("default"), _) => {
lexer.skip(Token::Separator(':'));
let body = self.block(lexer, ctx.reborrow())?.0;
cases.push(ast::SwitchCase {
value: ast::SwitchValue::Default,
value_span,
body,
fall_through: false,
});
@@ -2059,13 +2035,14 @@ impl Parser {
fn varying_binding<'a>(
&mut self,
lexer: &mut Lexer<'a>,
) -> Result<Option<crate::Binding>, Error<'a>> {
mut ctx: ExpressionContext<'a, '_, '_>,
) -> Result<Option<ast::Binding<'a>>, Error<'a>> {
let mut bind_parser = BindingParser::default();
self.push_rule_span(Rule::Attribute, lexer);
while lexer.skip(Token::Attribute) {
let (word, span) = lexer.next_ident_with_span()?;
bind_parser.parse(lexer, word, span)?;
bind_parser.parse(self, lexer, word, span, ctx.reborrow())?;
}
let span = self.pop_rule_span(lexer);
@@ -2106,7 +2083,7 @@ impl Parser {
ExpectedToken::Token(Token::Separator(',')),
));
}
let binding = self.varying_binding(lexer)?;
let binding = self.varying_binding(lexer, ctx.reborrow())?;
let param_name = lexer.next_ident()?;
@@ -2124,7 +2101,7 @@ impl Parser {
}
// read return type
let result = if lexer.skip(Token::Arrow) && !lexer.skip(Token::Word("void")) {
let binding = self.varying_binding(lexer)?;
let binding = self.varying_binding(lexer, ctx.reborrow())?;
let ty = self.type_decl(lexer, ctx.reborrow())?;
Some(ast::FunctionResult { ty, binding })
} else {
@@ -2169,17 +2146,26 @@ impl Parser {
let (mut bind_index, mut bind_group) =
(ParsedAttribute::default(), ParsedAttribute::default());
let mut dependencies = FastIndexSet::default();
let mut ctx = ExpressionContext {
expressions: &mut out.expressions,
local_table: &mut SymbolTable::default(),
locals: &mut Arena::new(),
types: &mut out.types,
unresolved: &mut dependencies,
};
self.push_rule_span(Rule::Attribute, lexer);
while lexer.skip(Token::Attribute) {
match lexer.next_ident_with_span()? {
("binding", name_span) => {
lexer.expect(Token::Paren('('))?;
bind_index.set(Self::non_negative_i32_literal(lexer)?, name_span)?;
bind_index.set(self.general_expression(lexer, ctx.reborrow())?, name_span)?;
lexer.expect(Token::Paren(')'))?;
}
("group", name_span) => {
lexer.expect(Token::Paren('('))?;
bind_group.set(Self::non_negative_i32_literal(lexer)?, name_span)?;
bind_group.set(self.general_expression(lexer, ctx.reborrow())?, name_span)?;
lexer.expect(Token::Paren(')'))?;
}
("vertex", name_span) => {
@@ -2194,9 +2180,9 @@ impl Parser {
}
("workgroup_size", name_span) => {
lexer.expect(Token::Paren('('))?;
let mut new_workgroup_size = [1u32; 3];
let mut new_workgroup_size = [None; 3];
for (i, size) in new_workgroup_size.iter_mut().enumerate() {
*size = Self::generic_non_negative_int_literal(lexer)?;
*size = Some(self.general_expression(lexer, ctx.reborrow())?);
match lexer.next() {
(Token::Paren(')'), _) => break,
(Token::Separator(','), _) if i != 2 => (),
@@ -2228,7 +2214,7 @@ impl Parser {
let attrib_span = self.pop_rule_span(lexer);
match (bind_group.value, bind_index.value) {
(Some(group), Some(index)) => {
binding = Some(crate::ResourceBinding {
binding = Some(ast::ResourceBinding {
group,
binding: index,
});
@@ -2238,15 +2224,6 @@ impl Parser {
(None, None) => {}
}
let mut dependencies = FastIndexSet::default();
let mut ctx = ExpressionContext {
expressions: &mut out.expressions,
local_table: &mut SymbolTable::default(),
locals: &mut Arena::new(),
types: &mut out.types,
unresolved: &mut dependencies,
};
// read item
let start = lexer.start_byte_offset();
let kind = match lexer.next() {

View File

@@ -402,9 +402,8 @@ fn binary_expression_mixed_scalar_and_vector_operands() {
] {
let module = parse_str(&format!(
"
const some_vec = vec3<f32>(1.0, 1.0, 1.0);
@fragment
fn main() -> @location(0) vec4<f32> {{
fn main(@location(0) some_vec: vec3<f32>) -> @location(0) vec4<f32> {{
if (all(1.0 {operand} some_vec)) {{
return vec4(0.0);
}}
@@ -431,7 +430,12 @@ fn binary_expression_mixed_scalar_and_vector_operands() {
})
.count();
assert_eq!(found_expressions, 1);
assert_eq!(
found_expressions,
1,
"expected `{operand}` expression {} splat",
if expect_splat { "with" } else { "without" }
);
}
let module = parse_str(

View File

@@ -2,12 +2,14 @@
[`Module`](super::Module) processing functionality.
*/
mod constant_evaluator;
pub mod index;
mod layouter;
mod namer;
mod terminator;
mod typifier;
pub use constant_evaluator::{ConstantEvaluator, ConstantEvaluatorError};
pub use index::{BoundsCheckPolicies, BoundsCheckPolicy, IndexableLength, IndexableLengthError};
pub use layouter::{Alignment, LayoutError, LayoutErrorInner, Layouter, TypeLayout};
pub use namer::{EntryPointIndex, NameKey, Namer};
@@ -590,28 +592,39 @@ impl GlobalCtx<'_> {
handle: crate::Handle<crate::Expression>,
arena: &crate::Arena<crate::Expression>,
) -> Result<u32, U32EvalError> {
match self.eval_expr_to_literal_from(handle, arena) {
Some(crate::Literal::U32(value)) => Ok(value),
Some(crate::Literal::I32(value)) => {
value.try_into().map_err(|_| U32EvalError::Negative)
}
_ => Err(U32EvalError::NonConst),
}
}
pub(crate) fn eval_expr_to_literal(
&self,
handle: crate::Handle<crate::Expression>,
) -> Option<crate::Literal> {
self.eval_expr_to_literal_from(handle, self.const_expressions)
}
pub(crate) fn eval_expr_to_literal_from(
&self,
handle: crate::Handle<crate::Expression>,
arena: &crate::Arena<crate::Expression>,
) -> Option<crate::Literal> {
fn get(
gctx: GlobalCtx,
handle: crate::Handle<crate::Expression>,
arena: &crate::Arena<crate::Expression>,
) -> Result<u32, U32EvalError> {
) -> Option<crate::Literal> {
match arena[handle] {
crate::Expression::Literal(crate::Literal::U32(value)) => Ok(value),
crate::Expression::Literal(crate::Literal::I32(value)) => {
value.try_into().map_err(|_| U32EvalError::Negative)
}
crate::Expression::ZeroValue(ty)
if matches!(
gctx.types[ty].inner,
crate::TypeInner::Scalar {
kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
width: _
}
) =>
{
Ok(0)
}
_ => Err(U32EvalError::NonConst),
crate::Expression::Literal(literal) => Some(literal),
crate::Expression::ZeroValue(ty) => match gctx.types[ty].inner {
crate::TypeInner::Scalar { kind, width } => crate::Literal::zero(kind, width),
_ => None,
},
_ => None,
}
}
match arena[handle] {

View File

@@ -134,6 +134,8 @@ pub enum ConstExpressionError {
NonConst,
#[error(transparent)]
Compose(#[from] super::ComposeError),
#[error("Splatting {0:?} can't be done")]
InvalidSplatType(Handle<crate::Expression>),
#[error("Type resolution failed")]
Type(#[from] ResolveError),
#[error(transparent)]
@@ -196,6 +198,10 @@ impl super::Validator {
components.iter().map(|&handle| mod_info[handle].clone()),
)?;
}
E::Splat { value, .. } => match *mod_info[value].inner_with(gctx.types) {
crate::TypeInner::Scalar { .. } => {}
_ => return Err(super::ConstExpressionError::InvalidSplatType(value)),
},
_ => return Err(super::ConstExpressionError::NonConst),
}