[wgsl-in] add support for override declarations (#4793)

Co-authored-by: Jim Blandy <jimb@red-bean.com>
This commit is contained in:
Teodor Tanasoaia
2023-12-07 20:19:43 +01:00
parent b3dfc40c9d
commit f949ea69c4
37 changed files with 515 additions and 28 deletions

View File

@@ -404,6 +404,7 @@ fn write_function_expressions(
let (label, color_id) = match *expression {
E::Literal(_) => ("Literal".into(), 2),
E::Constant(_) => ("Constant".into(), 2),
E::Override(_) => ("Override".into(), 2),
E::ZeroValue(_) => ("ZeroValue".into(), 2),
E::Compose { ref components, .. } => {
payload = Some(Payload::Arguments(components));

View File

@@ -2538,6 +2538,7 @@ impl<'a, W: Write> Writer<'a, W> {
|writer, expr| writer.write_expr(expr, ctx),
)?;
}
Expression::Override(_) => return Err(Error::Custom("overrides are WIP".into())),
// `Access` is applied to arrays, vectors and matrices and is written as indexing
Expression::Access { base, index } => {
self.write_expr(base, ctx)?;

View File

@@ -2141,6 +2141,9 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
|writer, expr| writer.write_expr(module, expr, func_ctx),
)?;
}
Expression::Override(_) => {
return Err(Error::Unimplemented("overrides are WIP".into()))
}
// All of the multiplication can be expressed as `mul`,
// except vector * vector, which needs to use the "*" operator.
Expression::Binary {

View File

@@ -1431,6 +1431,9 @@ impl<W: Write> Writer<W> {
|writer, context, expr| writer.put_expression(expr, context, true),
)?;
}
crate::Expression::Override(_) => {
return Err(Error::FeatureNotImplemented("overrides are WIP".into()))
}
crate::Expression::Access { base, .. }
| crate::Expression::AccessIndex { base, .. } => {
// This is an acceptable place to generate a `ReadZeroSkipWrite` check.

View File

@@ -239,6 +239,9 @@ impl<'w> BlockContext<'w> {
let init = self.ir_module.constants[handle].init;
self.writer.constant_ids[init.index()]
}
crate::Expression::Override(_) => {
return Err(Error::FeatureNotImplemented("overrides are WIP"))
}
crate::Expression::ZeroValue(_) => self.writer.get_constant_null(result_type_id),
crate::Expression::Compose { ty, ref components } => {
self.temp_list.clear();

View File

@@ -1199,6 +1199,9 @@ impl<W: Write> Writer<W> {
|writer, expr| writer.write_expr(module, expr, func_ctx),
)?;
}
Expression::Override(_) => {
return Err(Error::Unimplemented("overrides are WIP".into()))
}
Expression::FunctionArgument(pos) => {
let name_key = func_ctx.argument_key(pos);
let name = &self.names[&name_key];

View File

@@ -3,6 +3,7 @@ use crate::arena::{Arena, Handle};
pub struct ExpressionTracer<'tracer> {
pub constants: &'tracer Arena<crate::Constant>,
pub overrides: &'tracer Arena<crate::Override>,
/// The arena in which we are currently tracing expressions.
pub expressions: &'tracer Arena<crate::Expression>,
@@ -88,6 +89,11 @@ impl<'tracer> ExpressionTracer<'tracer> {
None => self.expressions_used.insert(init),
}
}
Ex::Override(_) => {
// All overrides are considered used by definition. We mark
// their types and initialization expressions as used in
// `compact::compact`, so we have no more work to do here.
}
Ex::ZeroValue(ty) => self.types_used.insert(ty),
Ex::Compose { ty, ref components } => {
self.types_used.insert(ty);
@@ -219,6 +225,9 @@ impl ModuleMap {
| Ex::CallResult(_)
| Ex::RayQueryProceedResult => {}
// All overrides are retained, so their handles never change.
Ex::Override(_) => {}
// Expressions that contain handles that need to be adjusted.
Ex::Constant(ref mut constant) => self.constants.adjust(constant),
Ex::ZeroValue(ref mut ty) => self.types.adjust(ty),

View File

@@ -4,6 +4,7 @@ use super::{FunctionMap, ModuleMap};
pub struct FunctionTracer<'a> {
pub function: &'a crate::Function,
pub constants: &'a crate::Arena<crate::Constant>,
pub overrides: &'a crate::Arena<crate::Override>,
pub types_used: &'a mut HandleSet<crate::Type>,
pub constants_used: &'a mut HandleSet<crate::Constant>,
@@ -47,6 +48,7 @@ impl<'a> FunctionTracer<'a> {
fn as_expression(&mut self) -> super::expressions::ExpressionTracer {
super::expressions::ExpressionTracer {
constants: self.constants,
overrides: self.overrides,
expressions: &self.function.expressions,
types_used: self.types_used,

View File

@@ -54,6 +54,14 @@ pub fn compact(module: &mut crate::Module) {
}
}
// We treat all overrides as used by definition.
for (_, override_) in module.overrides.iter() {
module_tracer.types_used.insert(override_.ty);
if let Some(init) = override_.init {
module_tracer.const_expressions_used.insert(init);
}
}
// We assume that all functions are used.
//
// Observe which types, constant expressions, constants, and
@@ -158,6 +166,15 @@ pub fn compact(module: &mut crate::Module) {
}
});
// Adjust override types and initializers.
log::trace!("adjusting overrides");
for (_, override_) in module.overrides.iter_mut() {
module_map.types.adjust(&mut override_.ty);
if let Some(init) = override_.init.as_mut() {
module_map.const_expressions.adjust(init);
}
}
// Adjust global variables' types and initializers.
log::trace!("adjusting global variables");
for (_, global) in module.global_variables.iter_mut() {
@@ -235,6 +252,7 @@ impl<'module> ModuleTracer<'module> {
expressions::ExpressionTracer {
expressions: &self.module.const_expressions,
constants: &self.module.constants,
overrides: &self.module.overrides,
types_used: &mut self.types_used,
constants_used: &mut self.constants_used,
expressions_used: &mut self.const_expressions_used,
@@ -249,6 +267,7 @@ impl<'module> ModuleTracer<'module> {
FunctionTracer {
function,
constants: &self.module.constants,
overrides: &self.module.overrides,
types_used: &mut self.types_used,
constants_used: &mut self.constants_used,
const_expressions_used: &mut self.const_expressions_used,

View File

@@ -128,6 +128,7 @@ impl<I: Iterator<Item = u32>> super::Frontend<I> {
expressions: &mut fun.expressions,
local_arena: &mut fun.local_variables,
const_arena: &mut module.constants,
overrides: &mut module.overrides,
const_expressions: &mut module.const_expressions,
type_arena: &module.types,
global_arena: &module.global_variables,
@@ -581,6 +582,7 @@ impl<'function> BlockContext<'function> {
crate::proc::GlobalCtx {
types: self.type_arena,
constants: self.const_arena,
overrides: self.overrides,
const_expressions: self.const_expressions,
}
}

View File

@@ -531,6 +531,7 @@ struct BlockContext<'function> {
local_arena: &'function mut Arena<crate::LocalVariable>,
/// Constants arena of the module being processed
const_arena: &'function mut Arena<crate::Constant>,
overrides: &'function mut Arena<crate::Override>,
const_expressions: &'function mut Arena<crate::Expression>,
/// Type arena of the module being processed
type_arena: &'function UniqueArena<crate::Type>,
@@ -3934,7 +3935,7 @@ impl<I: Iterator<Item = u32>> Frontend<I> {
Op::TypeImage => self.parse_type_image(inst, &mut module),
Op::TypeSampledImage => self.parse_type_sampled_image(inst),
Op::TypeSampler => self.parse_type_sampler(inst, &mut module),
Op::Constant | Op::SpecConstant => self.parse_constant(inst, &mut module),
Op::Constant => self.parse_constant(inst, &mut module),
Op::ConstantComposite => self.parse_composite_constant(inst, &mut module),
Op::ConstantNull | Op::Undef => self.parse_null_constant(inst, &mut module),
Op::ConstantTrue => self.parse_bool_constant(inst, true, &mut module),

View File

@@ -190,7 +190,7 @@ pub enum Error<'a> {
expected: String,
got: String,
},
MissingType(Span),
DeclMissingTypeAndInit(Span),
MissingAttribute(&'static str, Span),
InvalidAtomicPointer(Span),
InvalidAtomicOperandType(Span),
@@ -273,6 +273,7 @@ pub enum Error<'a> {
span: Span,
limit: u8,
},
PipelineConstantIDValue(Span),
}
impl<'a> Error<'a> {
@@ -522,11 +523,11 @@ impl<'a> Error<'a> {
notes: vec![],
}
}
Error::MissingType(name_span) => ParseError {
message: format!("variable `{}` needs a type", &source[name_span]),
Error::DeclMissingTypeAndInit(name_span) => ParseError {
message: format!("declaration of `{}` needs a type specifier or initializer", &source[name_span]),
labels: vec![(
name_span,
format!("definition of `{}`", &source[name_span]).into(),
"needs a type specifier or initializer".into(),
)],
notes: vec![],
},
@@ -781,6 +782,14 @@ impl<'a> Error<'a> {
format!("nesting limit is currently set to {limit}"),
],
},
Error::PipelineConstantIDValue(span) => ParseError {
message: "pipeline constant ID must be between 0 and 65535 inclusive".to_string(),
labels: vec![(
span,
"must be between 0 and 65535 inclusive".into(),
)],
notes: vec![],
},
}
}
}

View File

@@ -187,6 +187,7 @@ const fn decl_ident<'a>(decl: &ast::GlobalDecl<'a>) -> ast::Ident<'a> {
ast::GlobalDeclKind::Fn(ref f) => f.name,
ast::GlobalDeclKind::Var(ref v) => v.name,
ast::GlobalDeclKind::Const(ref c) => c.name,
ast::GlobalDeclKind::Override(ref o) => o.name,
ast::GlobalDeclKind::Struct(ref s) => s.name,
ast::GlobalDeclKind::Type(ref t) => t.name,
}

View File

@@ -786,6 +786,7 @@ enum LoweredGlobalDecl {
Function(Handle<crate::Function>),
Var(Handle<crate::GlobalVariable>),
Const(Handle<crate::Constant>),
Override(Handle<crate::Override>),
Type(Handle<crate::Type>),
EntryPoint,
}
@@ -965,6 +966,65 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
ctx.globals
.insert(c.name.name, LoweredGlobalDecl::Const(handle));
}
ast::GlobalDeclKind::Override(ref o) => {
let init = o
.init
.map(|init| self.expression(init, &mut ctx.as_const()))
.transpose()?;
let inferred_type = init
.map(|init| ctx.as_const().register_type(init))
.transpose()?;
let explicit_ty =
o.ty.map(|ty| self.resolve_ast_type(ty, &mut ctx))
.transpose()?;
let id =
o.id.map(|id| self.const_u32(id, &mut ctx.as_const()))
.transpose()?;
let id = if let Some((id, id_span)) = id {
Some(
u16::try_from(id)
.map_err(|_| Error::PipelineConstantIDValue(id_span))?,
)
} else {
None
};
let ty = match (explicit_ty, inferred_type) {
(Some(explicit_ty), Some(inferred_type)) => {
if explicit_ty == inferred_type {
explicit_ty
} else {
let gctx = ctx.module.to_ctx();
return Err(Error::InitializationTypeMismatch {
name: o.name.span,
expected: explicit_ty.to_wgsl(&gctx),
got: inferred_type.to_wgsl(&gctx),
});
}
}
(Some(explicit_ty), None) => explicit_ty,
(None, Some(inferred_type)) => inferred_type,
(None, None) => {
return Err(Error::DeclMissingTypeAndInit(o.name.span));
}
};
let handle = ctx.module.overrides.append(
crate::Override {
name: Some(o.name.name.to_string()),
id,
ty,
init,
},
span,
);
ctx.globals
.insert(o.name.name, LoweredGlobalDecl::Override(handle));
}
ast::GlobalDeclKind::Struct(ref s) => {
let handle = self.r#struct(s, span, &mut ctx)?;
ctx.globals
@@ -1202,7 +1262,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
ty = explicit_ty;
initializer = None;
}
(None, None) => return Err(Error::MissingType(v.name.span)),
(None, None) => return Err(Error::DeclMissingTypeAndInit(v.name.span)),
}
let (const_initializer, initializer) = {
@@ -1818,9 +1878,11 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
)?;
Ok(Some(handle))
}
Some(&LoweredGlobalDecl::Const(_) | &LoweredGlobalDecl::Var(_)) => {
Err(Error::Unexpected(function.span, ExpectedToken::Function))
}
Some(
&LoweredGlobalDecl::Const(_)
| &LoweredGlobalDecl::Override(_)
| &LoweredGlobalDecl::Var(_),
) => Err(Error::Unexpected(function.span, ExpectedToken::Function)),
Some(&LoweredGlobalDecl::EntryPoint) => Err(Error::CalledEntryPoint(function.span)),
Some(&LoweredGlobalDecl::Function(function)) => {
let arguments = arguments

View File

@@ -82,6 +82,7 @@ pub enum GlobalDeclKind<'a> {
Fn(Function<'a>),
Var(GlobalVariable<'a>),
Const(Const<'a>),
Override(Override<'a>),
Struct(Struct<'a>),
Type(TypeAlias<'a>),
}
@@ -200,6 +201,14 @@ pub struct Const<'a> {
pub init: Handle<Expression<'a>>,
}
#[derive(Debug)]
pub struct Override<'a> {
pub name: Ident<'a>,
pub id: Option<Handle<Expression<'a>>>,
pub ty: Option<Handle<Type<'a>>>,
pub init: Option<Handle<Expression<'a>>>,
}
/// The size of an [`Array`] or [`BindingArray`].
///
/// [`Array`]: Type::Array

View File

@@ -2180,6 +2180,7 @@ impl Parser {
let mut early_depth_test = ParsedAttribute::default();
let (mut bind_index, mut bind_group) =
(ParsedAttribute::default(), ParsedAttribute::default());
let mut id = ParsedAttribute::default();
let mut dependencies = FastIndexSet::default();
let mut ctx = ExpressionContext {
@@ -2203,6 +2204,11 @@ impl Parser {
bind_group.set(self.general_expression(lexer, &mut ctx)?, name_span)?;
lexer.expect(Token::Paren(')'))?;
}
("id", name_span) => {
lexer.expect(Token::Paren('('))?;
id.set(self.general_expression(lexer, &mut ctx)?, name_span)?;
lexer.expect(Token::Paren(')'))?;
}
("vertex", name_span) => {
stage.set(crate::ShaderStage::Vertex, name_span)?;
}
@@ -2293,6 +2299,30 @@ impl Parser {
Some(ast::GlobalDeclKind::Const(ast::Const { name, ty, init }))
}
(Token::Word("override"), _) => {
let name = lexer.next_ident()?;
let ty = if lexer.skip(Token::Separator(':')) {
Some(self.type_decl(lexer, &mut ctx)?)
} else {
None
};
let init = if lexer.skip(Token::Operation('=')) {
Some(self.general_expression(lexer, &mut ctx)?)
} else {
None
};
lexer.expect(Token::Separator(';'))?;
Some(ast::GlobalDeclKind::Override(ast::Override {
name,
id: id.value,
ty,
init,
}))
}
(Token::Word("var"), _) => {
let mut var = self.variable_decl(lexer, &mut ctx)?;
var.binding = binding.take();

View File

@@ -226,6 +226,7 @@ mod tests {
let gctx = crate::proc::GlobalCtx {
types: &types,
constants: &crate::Arena::new(),
overrides: &crate::Arena::new(),
const_expressions: &crate::Arena::new(),
};
let array = crate::TypeInner::Array {

View File

@@ -175,7 +175,7 @@ tree.
A Naga *constant expression* is one of the following [`Expression`]
variants, whose operands (if any) are also constant expressions:
- [`Literal`]
- [`Constant`], for [`Constant`s][const_type] whose `override` is `None`
- [`Constant`], for [`Constant`]s
- [`ZeroValue`], for fixed-size types
- [`Compose`]
- [`Access`]
@@ -194,8 +194,7 @@ A constant expression can be evaluated at module translation time.
## Override expressions
A Naga *override expression* is the same as a [constant expression],
except that it is also allowed to refer to [`Constant`s][const_type]
whose `override` is something other than `None`.
except that it is also allowed to reference other [`Override`]s.
An override expression can be evaluated at pipeline creation time.
@@ -238,8 +237,6 @@ An override expression can be evaluated at pipeline creation time.
[`Math`]: Expression::Math
[`As`]: Expression::As
[const_type]: Constant
[constant expression]: index.html#constant-expressions
*/
@@ -890,6 +887,25 @@ pub enum Literal {
AbstractFloat(f64),
}
/// Pipeline-overridable constant.
#[derive(Debug, PartialEq)]
#[cfg_attr(feature = "clone", derive(Clone))]
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[cfg_attr(feature = "deserialize", derive(Deserialize))]
#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
pub struct Override {
pub name: Option<String>,
/// Pipeline Constant ID.
pub id: Option<u16>,
pub ty: Handle<Type>,
/// The default value of the pipeline-overridable constant.
///
/// This [`Handle`] refers to [`Module::const_expressions`], not
/// any [`Function::expressions`] arena.
pub init: Option<Handle<Expression>>,
}
/// Constant value.
#[derive(Debug, PartialEq)]
#[cfg_attr(feature = "clone", derive(Clone))]
@@ -904,13 +920,6 @@ pub struct Constant {
///
/// This [`Handle`] refers to [`Module::const_expressions`], not
/// any [`Function::expressions`] arena.
///
/// If `override` is `None`, then this must be a Naga
/// [constant expression]. Otherwise, this may be a Naga
/// [override expression] or [constant expression].
///
/// [constant expression]: index.html#constant-expressions
/// [override expression]: index.html#override-expressions
pub init: Handle<Expression>,
}
@@ -1299,6 +1308,8 @@ pub enum Expression {
Literal(Literal),
/// Constant value.
Constant(Handle<Constant>),
/// Pipeline-overridable constant.
Override(Handle<Override>),
/// Zero value of a type.
ZeroValue(Handle<Type>),
/// Composite expression.
@@ -2053,6 +2064,8 @@ pub struct Module {
pub special_types: SpecialTypes,
/// Arena for the constants defined in this module.
pub constants: Arena<Constant>,
/// Arena for the pipeline-overridable constants defined in this module.
pub overrides: Arena<Override>,
/// Arena for the global variables defined in this module.
pub global_variables: Arena<GlobalVariable>,
/// [Constant expressions] and [override expressions] used by this module.

View File

@@ -4,8 +4,8 @@ use arrayvec::ArrayVec;
use crate::{
arena::{Arena, Handle, UniqueArena},
ArraySize, BinaryOperator, Constant, Expression, Literal, ScalarKind, Span, Type, TypeInner,
UnaryOperator,
ArraySize, BinaryOperator, Constant, Expression, Literal, Override, ScalarKind, Span, Type,
TypeInner, UnaryOperator,
};
/// A macro that allows dollar signs (`$`) to be emitted by other macros. Useful for generating
@@ -291,6 +291,9 @@ pub struct ConstantEvaluator<'a> {
/// The module's constant arena.
constants: &'a Arena<Constant>,
/// The module's override arena.
overrides: &'a Arena<Override>,
/// The arena to which we are contributing expressions.
expressions: &'a mut Arena<Expression>,
@@ -456,6 +459,7 @@ impl<'a> ConstantEvaluator<'a> {
behavior,
types: &mut module.types,
constants: &module.constants,
overrides: &module.overrides,
expressions: &mut module.const_expressions,
function_local_data: None,
}
@@ -515,6 +519,7 @@ impl<'a> ConstantEvaluator<'a> {
behavior,
types: &mut module.types,
constants: &module.constants,
overrides: &module.overrides,
expressions,
function_local_data: Some(FunctionLocalData {
const_expressions: &module.const_expressions,
@@ -529,6 +534,7 @@ impl<'a> ConstantEvaluator<'a> {
crate::proc::GlobalCtx {
types: self.types,
constants: self.constants,
overrides: self.overrides,
const_expressions: match self.function_local_data {
Some(ref data) => data.const_expressions,
None => self.expressions,
@@ -605,6 +611,9 @@ impl<'a> ConstantEvaluator<'a> {
// This is mainly done to avoid having constants pointing to other constants.
Ok(self.constants[c].init)
}
Expression::Override(_) => Err(ConstantEvaluatorError::NotImplemented(
"overrides are WIP".into(),
)),
Expression::Literal(_) | Expression::ZeroValue(_) | Expression::Constant(_) => {
self.register_evaluated_expr(expr.clone(), span)
}
@@ -2035,6 +2044,7 @@ mod tests {
fn unary_op() {
let mut types = UniqueArena::new();
let mut constants = Arena::new();
let overrides = Arena::new();
let mut const_expressions = Arena::new();
let scalar_ty = types.insert(
@@ -2113,6 +2123,7 @@ mod tests {
behavior: Behavior::Wgsl,
types: &mut types,
constants: &constants,
overrides: &overrides,
expressions: &mut const_expressions,
function_local_data: None,
};
@@ -2164,6 +2175,7 @@ mod tests {
fn cast() {
let mut types = UniqueArena::new();
let mut constants = Arena::new();
let overrides = Arena::new();
let mut const_expressions = Arena::new();
let scalar_ty = types.insert(
@@ -2196,6 +2208,7 @@ mod tests {
behavior: Behavior::Wgsl,
types: &mut types,
constants: &constants,
overrides: &overrides,
expressions: &mut const_expressions,
function_local_data: None,
};
@@ -2214,6 +2227,7 @@ mod tests {
fn access() {
let mut types = UniqueArena::new();
let mut constants = Arena::new();
let overrides = Arena::new();
let mut const_expressions = Arena::new();
let matrix_ty = types.insert(
@@ -2311,6 +2325,7 @@ mod tests {
behavior: Behavior::Wgsl,
types: &mut types,
constants: &constants,
overrides: &overrides,
expressions: &mut const_expressions,
function_local_data: None,
};
@@ -2364,6 +2379,7 @@ mod tests {
fn compose_of_constants() {
let mut types = UniqueArena::new();
let mut constants = Arena::new();
let overrides = Arena::new();
let mut const_expressions = Arena::new();
let i32_ty = types.insert(
@@ -2401,6 +2417,7 @@ mod tests {
behavior: Behavior::Wgsl,
types: &mut types,
constants: &constants,
overrides: &overrides,
expressions: &mut const_expressions,
function_local_data: None,
};
@@ -2443,6 +2460,7 @@ mod tests {
fn splat_of_constant() {
let mut types = UniqueArena::new();
let mut constants = Arena::new();
let overrides = Arena::new();
let mut const_expressions = Arena::new();
let i32_ty = types.insert(
@@ -2480,6 +2498,7 @@ mod tests {
behavior: Behavior::Wgsl,
types: &mut types,
constants: &constants,
overrides: &overrides,
expressions: &mut const_expressions,
function_local_data: None,
};

View File

@@ -648,6 +648,7 @@ impl crate::Module {
GlobalCtx {
types: &self.types,
constants: &self.constants,
overrides: &self.overrides,
const_expressions: &self.const_expressions,
}
}
@@ -663,6 +664,7 @@ pub(super) enum U32EvalError {
pub struct GlobalCtx<'a> {
pub types: &'a crate::UniqueArena<crate::Type>,
pub constants: &'a crate::Arena<crate::Constant>,
pub overrides: &'a crate::Arena<crate::Override>,
pub const_expressions: &'a crate::Arena<crate::Expression>,
}

View File

@@ -185,6 +185,7 @@ pub enum ResolveError {
pub struct ResolveContext<'a> {
pub constants: &'a Arena<crate::Constant>,
pub overrides: &'a Arena<crate::Override>,
pub types: &'a UniqueArena<crate::Type>,
pub special_types: &'a crate::SpecialTypes,
pub global_vars: &'a Arena<crate::GlobalVariable>,
@@ -202,6 +203,7 @@ impl<'a> ResolveContext<'a> {
) -> Self {
Self {
constants: &module.constants,
overrides: &module.overrides,
types: &module.types,
special_types: &module.special_types,
global_vars: &module.global_variables,
@@ -407,6 +409,7 @@ impl<'a> ResolveContext<'a> {
},
crate::Expression::Literal(lit) => TypeResolution::Value(lit.ty_inner()),
crate::Expression::Constant(h) => TypeResolution::Handle(self.constants[h].ty),
crate::Expression::Override(h) => TypeResolution::Handle(self.overrides[h].ty),
crate::Expression::ZeroValue(ty) => TypeResolution::Handle(ty),
crate::Expression::Compose { ty, .. } => TypeResolution::Handle(ty),
crate::Expression::FunctionArgument(index) => {

View File

@@ -574,7 +574,7 @@ impl FunctionInfo {
non_uniform_result: self.add_ref(vector),
requirements: UniformityRequirements::empty(),
},
E::Literal(_) | E::Constant(_) | E::ZeroValue(_) => Uniformity::new(),
E::Literal(_) | E::Constant(_) | E::Override(_) | E::ZeroValue(_) => Uniformity::new(),
E::Compose { ref components, .. } => {
let non_uniform_result = components
.iter()
@@ -1186,6 +1186,7 @@ fn uniform_control_flow() {
};
let resolve_context = ResolveContext {
constants: &Arena::new(),
overrides: &Arena::new(),
types: &type_arena,
special_types: &crate::SpecialTypes::default(),
global_vars: &global_var_arena,

View File

@@ -345,7 +345,7 @@ impl super::Validator {
self.validate_literal(literal)?;
ShaderStages::all()
}
E::Constant(_) | E::ZeroValue(_) => ShaderStages::all(),
E::Constant(_) | E::Override(_) | E::ZeroValue(_) => ShaderStages::all(),
E::Compose { ref components, ty } => {
validate_compose(
ty,

View File

@@ -31,6 +31,7 @@ impl super::Validator {
pub(super) fn validate_module_handles(module: &crate::Module) -> Result<(), ValidationError> {
let &crate::Module {
ref constants,
ref overrides,
ref entry_points,
ref functions,
ref global_variables,
@@ -68,7 +69,7 @@ impl super::Validator {
}
for handle_and_expr in const_expressions.iter() {
Self::validate_const_expression_handles(handle_and_expr, constants, types)?;
Self::validate_const_expression_handles(handle_and_expr, constants, overrides, types)?;
}
let validate_type = |handle| Self::validate_type_handle(handle, types);
@@ -81,6 +82,19 @@ impl super::Validator {
validate_const_expr(init)?;
}
for (_handle, override_) in overrides.iter() {
let &crate::Override {
name: _,
id: _,
ty,
init,
} = override_;
validate_type(ty)?;
if let Some(init_expr) = init {
validate_const_expr(init_expr)?;
}
}
for (_handle, global_variable) in global_variables.iter() {
let &crate::GlobalVariable {
name: _,
@@ -135,6 +149,7 @@ impl super::Validator {
Self::validate_expression_handles(
handle_and_expr,
constants,
overrides,
const_expressions,
types,
local_variables,
@@ -181,6 +196,13 @@ impl super::Validator {
handle.check_valid_for(constants).map(|_| ())
}
fn validate_override_handle(
handle: Handle<crate::Override>,
overrides: &Arena<crate::Override>,
) -> Result<(), InvalidHandleError> {
handle.check_valid_for(overrides).map(|_| ())
}
fn validate_expression_handle(
handle: Handle<crate::Expression>,
expressions: &Arena<crate::Expression>,
@@ -198,9 +220,11 @@ impl super::Validator {
fn validate_const_expression_handles(
(handle, expression): (Handle<crate::Expression>, &crate::Expression),
constants: &Arena<crate::Constant>,
overrides: &Arena<crate::Override>,
types: &UniqueArena<crate::Type>,
) -> Result<(), InvalidHandleError> {
let validate_constant = |handle| Self::validate_constant_handle(handle, constants);
let validate_override = |handle| Self::validate_override_handle(handle, overrides);
let validate_type = |handle| Self::validate_type_handle(handle, types);
match *expression {
@@ -209,6 +233,12 @@ impl super::Validator {
validate_constant(constant)?;
handle.check_dep(constants[constant].init)?;
}
crate::Expression::Override(override_) => {
validate_override(override_)?;
if let Some(init) = overrides[override_].init {
handle.check_dep(init)?;
}
}
crate::Expression::ZeroValue(ty) => {
validate_type(ty)?;
}
@@ -225,6 +255,7 @@ impl super::Validator {
fn validate_expression_handles(
(handle, expression): (Handle<crate::Expression>, &crate::Expression),
constants: &Arena<crate::Constant>,
overrides: &Arena<crate::Override>,
const_expressions: &Arena<crate::Expression>,
types: &UniqueArena<crate::Type>,
local_variables: &Arena<crate::LocalVariable>,
@@ -234,6 +265,7 @@ impl super::Validator {
current_function: Option<Handle<crate::Function>>,
) -> Result<(), InvalidHandleError> {
let validate_constant = |handle| Self::validate_constant_handle(handle, constants);
let validate_override = |handle| Self::validate_override_handle(handle, overrides);
let validate_const_expr =
|handle| Self::validate_expression_handle(handle, const_expressions);
let validate_type = |handle| Self::validate_type_handle(handle, types);
@@ -255,6 +287,9 @@ impl super::Validator {
crate::Expression::Constant(constant) => {
validate_constant(constant)?;
}
crate::Expression::Override(override_) => {
validate_override(override_)?;
}
crate::Expression::ZeroValue(ty) => {
validate_type(ty)?;
}
@@ -659,6 +694,7 @@ fn constant_deps() {
let mut const_exprs = Arena::new();
let mut fun_exprs = Arena::new();
let mut constants = Arena::new();
let overrides = Arena::new();
let i32_handle = types.insert(
Type {
@@ -686,6 +722,7 @@ fn constant_deps() {
assert!(super::Validator::validate_const_expression_handles(
handle_and_expr,
&constants,
&overrides,
&types,
)
.is_err());

View File

@@ -184,6 +184,16 @@ pub enum ConstantError {
NonConstructibleType,
}
#[derive(Clone, Debug, thiserror::Error)]
pub enum OverrideError {
#[error("The type doesn't match the override")]
InvalidType,
#[error("The type is not constructible")]
NonConstructibleType,
#[error("The type is not a scalar")]
TypeNotScalar,
}
#[derive(Clone, Debug, thiserror::Error)]
pub enum ValidationError {
#[error(transparent)]
@@ -207,6 +217,12 @@ pub enum ValidationError {
name: String,
source: ConstantError,
},
#[error("Override {handle:?} '{name}' is invalid")]
Override {
handle: Handle<crate::Override>,
name: String,
source: OverrideError,
},
#[error("Global variable {handle:?} '{name}' is invalid")]
GlobalVariable {
handle: Handle<crate::GlobalVariable>,
@@ -329,6 +345,35 @@ impl Validator {
Ok(())
}
fn validate_override(
&self,
handle: Handle<crate::Override>,
gctx: crate::proc::GlobalCtx,
mod_info: &ModuleInfo,
) -> Result<(), OverrideError> {
let o = &gctx.overrides[handle];
let type_info = &self.types[o.ty.index()];
if !type_info.flags.contains(TypeFlags::CONSTRUCTIBLE) {
return Err(OverrideError::NonConstructibleType);
}
let decl_ty = &gctx.types[o.ty].inner;
match decl_ty {
&crate::TypeInner::Scalar(_) => {}
_ => return Err(OverrideError::TypeNotScalar),
}
if let Some(init) = o.init {
let init_ty = mod_info[init].inner_with(gctx.types);
if !decl_ty.equivalent(init_ty, gctx.types) {
return Err(OverrideError::InvalidType);
}
}
Ok(())
}
/// Check the given module to be valid.
pub fn validate(
&mut self,
@@ -406,6 +451,18 @@ impl Validator {
.with_span_handle(handle, &module.constants)
})?
}
for (handle, override_) in module.overrides.iter() {
self.validate_override(handle, module.to_ctx(), &mod_info)
.map_err(|source| {
ValidationError::Override {
handle,
name: override_.name.clone().unwrap_or_default(),
source,
}
.with_span_handle(handle, &module.overrides)
})?
}
}
for (var_handle, var) in module.global_variables.iter() {

View File

@@ -0,0 +1,14 @@
@id(0) override has_point_light: bool = true; // Algorithmic control
@id(1200) override specular_param: f32 = 2.3; // Numeric control
@id(1300) override gain: f32; // Must be overridden
override width: f32 = 0.0; // Specified at the API level using
// the name "width".
override depth: f32; // Specified at the API level using
// the name "depth".
// Must be overridden.
// override height = 2 * depth; // The default value
// (if not set at the API level),
// depends on another
// overridable constant.
override inferred_f32 = 2.718;

View File

@@ -0,0 +1,26 @@
(
type_flags: [
("DATA | SIZED | COPY | ARGUMENT | CONSTRUCTIBLE"),
("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | ARGUMENT | CONSTRUCTIBLE"),
],
functions: [],
entry_points: [],
const_expression_types: [
Value(Scalar((
kind: Bool,
width: 1,
))),
Value(Scalar((
kind: Float,
width: 4,
))),
Value(Scalar((
kind: Float,
width: 4,
))),
Value(Scalar((
kind: Float,
width: 4,
))),
],
)

View File

@@ -324,6 +324,7 @@
predeclared_types: {},
),
constants: [],
overrides: [],
global_variables: [
(
name: Some("global_const"),

View File

@@ -324,6 +324,7 @@
predeclared_types: {},
),
constants: [],
overrides: [],
global_variables: [
(
name: Some("global_const"),

View File

@@ -46,6 +46,7 @@
predeclared_types: {},
),
constants: [],
overrides: [],
global_variables: [
(
name: Some("v_indices"),

View File

@@ -46,6 +46,7 @@
predeclared_types: {},
),
constants: [],
overrides: [],
global_variables: [
(
name: Some("v_indices"),

View File

@@ -0,0 +1,71 @@
(
types: [
(
name: None,
inner: Scalar((
kind: Bool,
width: 1,
)),
),
(
name: None,
inner: Scalar((
kind: Float,
width: 4,
)),
),
],
special_types: (
ray_desc: None,
ray_intersection: None,
predeclared_types: {},
),
constants: [],
overrides: [
(
name: Some("has_point_light"),
id: Some(0),
ty: 1,
init: Some(1),
),
(
name: Some("specular_param"),
id: Some(1200),
ty: 2,
init: Some(2),
),
(
name: Some("gain"),
id: Some(1300),
ty: 2,
init: None,
),
(
name: Some("width"),
id: None,
ty: 2,
init: Some(3),
),
(
name: Some("depth"),
id: None,
ty: 2,
init: None,
),
(
name: Some("inferred_f32"),
id: None,
ty: 2,
init: Some(4),
),
],
global_variables: [],
const_expressions: [
Literal(Bool(true)),
Literal(F32(2.3)),
Literal(F32(0.0)),
Literal(F32(2.718)),
],
functions: [],
entry_points: [],
)

View File

@@ -0,0 +1,71 @@
(
types: [
(
name: None,
inner: Scalar((
kind: Bool,
width: 1,
)),
),
(
name: None,
inner: Scalar((
kind: Float,
width: 4,
)),
),
],
special_types: (
ray_desc: None,
ray_intersection: None,
predeclared_types: {},
),
constants: [],
overrides: [
(
name: Some("has_point_light"),
id: Some(0),
ty: 1,
init: Some(1),
),
(
name: Some("specular_param"),
id: Some(1200),
ty: 2,
init: Some(2),
),
(
name: Some("gain"),
id: Some(1300),
ty: 2,
init: None,
),
(
name: Some("width"),
id: None,
ty: 2,
init: Some(3),
),
(
name: Some("depth"),
id: None,
ty: 2,
init: None,
),
(
name: Some("inferred_f32"),
id: None,
ty: 2,
init: Some(4),
),
],
global_variables: [],
const_expressions: [
Literal(Bool(true)),
Literal(F32(2.3)),
Literal(F32(0.0)),
Literal(F32(2.718)),
],
functions: [],
entry_points: [],
)

View File

@@ -253,6 +253,7 @@
init: 22,
),
],
overrides: [],
global_variables: [
(
name: Some("t_shadow"),

View File

@@ -456,6 +456,7 @@
init: 38,
),
],
overrides: [],
global_variables: [
(
name: Some("t_shadow"),

View File

@@ -815,6 +815,14 @@ fn convert_wgsl() {
"int64",
Targets::SPIRV | Targets::HLSL | Targets::WGSL | Targets::METAL,
),
(
"overrides",
Targets::IR | Targets::ANALYSIS, // | Targets::SPIRV
// | Targets::METAL
// | Targets::GLSL
// | Targets::HLSL
// | Targets::WGSL,
),
];
for &(name, targets) in inputs.iter() {

View File

@@ -570,11 +570,11 @@ fn local_var_missing_type() {
var x;
}
"#,
r#"error: variable `x` needs a type
r#"error: declaration of `x` needs a type specifier or initializer
┌─ wgsl:3:21
3 │ var x;
│ ^ definition of `x`
│ ^ needs a type specifier or initializer
"#,
);