From beabd62d96f2fadac7ca599a5e76593c47a522eb Mon Sep 17 00:00:00 2001 From: Dzmitry Malyshau Date: Mon, 5 Jul 2021 16:06:30 -0400 Subject: [PATCH] wgsl: type inference for local variables --- src/front/glsl/ast.rs | 4 +-- src/front/glsl/functions.rs | 26 ++++++-------- src/front/mod.rs | 8 +++++ src/front/wgsl/mod.rs | 71 +++++++++++++++++++++++++++++++------ src/front/wgsl/tests.rs | 2 ++ tests/wgsl-errors.rs | 30 ++++++++-------- 6 files changed, 99 insertions(+), 42 deletions(-) diff --git a/src/front/glsl/ast.rs b/src/front/glsl/ast.rs index 1247adcb1f..ddeff5d5c5 100644 --- a/src/front/glsl/ast.rs +++ b/src/front/glsl/ast.rs @@ -492,8 +492,8 @@ impl<'function> Context<'function> { Some(self.add_expression(Expression::Load { pointer }, body)), meta, )); - }, - _ => {}, + } + _ => {} } } diff --git a/src/front/glsl/functions.rs b/src/front/glsl/functions.rs index 15d8ba7536..a5b9c56c19 100644 --- a/src/front/glsl/functions.rs +++ b/src/front/glsl/functions.rs @@ -1,18 +1,14 @@ use crate::{ - proc::ensure_block_returns, Arena, BinaryOperator, Block, Constant, ConstantInner, EntryPoint, Expression, Function, - FunctionArgument, FunctionResult, Handle, ImageQuery, LocalVariable, MathFunction, - RelationalFunction, SampleLevel, ScalarKind, ScalarValue, Statement, StructMember, SwizzleComponent, Type, - TypeInner, VectorSize, + proc::ensure_block_returns, Arena, BinaryOperator, Block, Constant, ConstantInner, EntryPoint, + Expression, Function, FunctionArgument, FunctionResult, Handle, ImageQuery, LocalVariable, + MathFunction, RelationalFunction, SampleLevel, ScalarKind, ScalarValue, Statement, + StructMember, SwizzleComponent, Type, TypeInner, VectorSize, }; use super::{ast::*, error::ErrorKind, SourceMetadata}; impl Program<'_> { - fn add_constant_value( - &mut self, - scalar_kind: ScalarKind, - value: u64, - ) -> Handle { + fn add_constant_value(&mut self, scalar_kind: ScalarKind, value: u64) -> Handle { let value = match scalar_kind { ScalarKind::Uint => ScalarValue::Uint(value), ScalarKind::Sint => ScalarValue::Sint(value as i64), @@ -23,10 +19,7 @@ impl Program<'_> { self.module.constants.fetch_or_append(Constant { name: None, specialization: None, - inner: ConstantInner::Scalar { - width: 4, - value, - }, + inner: ConstantInner::Scalar { width: 4, value }, }) } @@ -49,13 +42,16 @@ impl Program<'_> { let expr_type = self.resolve_type(ctx, args[0].0, args[0].1)?; let vector_size = match *expr_type { - TypeInner::Vector{ size, .. } => Some(size), + TypeInner::Vector { size, .. } => Some(size), _ => None, }; // Special case: if casting from a bool, we need to use Select and not As. match self.module.types[ty].inner.scalar_kind() { - Some(result_scalar_kind) if expr_type.scalar_kind() == Some(ScalarKind::Bool) && result_scalar_kind != ScalarKind::Bool => { + Some(result_scalar_kind) + if expr_type.scalar_kind() == Some(ScalarKind::Bool) + && result_scalar_kind != ScalarKind::Bool => + { let c0 = self.add_constant_value(result_scalar_kind, 0u64); let c1 = self.add_constant_value(result_scalar_kind, 1u64); let mut reject = ctx.add_expression(Expression::Constant(c0), body); diff --git a/src/front/mod.rs b/src/front/mod.rs index 8181592922..9989b21328 100644 --- a/src/front/mod.rs +++ b/src/front/mod.rs @@ -11,6 +11,7 @@ use crate::{ arena::{Arena, Handle}, proc::{ResolveContext, ResolveError, TypeResolution}, }; +use std::ops; /// Helper class to emit expressions #[allow(dead_code)] @@ -85,3 +86,10 @@ impl Typifier { Ok(()) } } + +impl ops::Index> for Typifier { + type Output = TypeResolution; + fn index(&self, handle: Handle) -> &Self::Output { + &self.resolutions[handle.index()] + } +} diff --git a/src/front/wgsl/mod.rs b/src/front/wgsl/mod.rs index 861f1e4174..86b29dd0ab 100644 --- a/src/front/wgsl/mod.rs +++ b/src/front/wgsl/mod.rs @@ -124,7 +124,8 @@ pub enum Error<'a> { ZeroSizeOrAlign(Span), InconsistentBinding(Span), UnknownLocalFunction(Span), - LetTypeMismatch(Span, Handle), + InitializationTypeMismatch(Span, Handle), + MissingType(Span), Other, } @@ -326,11 +327,16 @@ impl<'a> Error<'a> { labels: vec![(span.clone(), "unknown local function".into())], notes: vec![], }, - Error::LetTypeMismatch(ref name_span, ref expected_ty) => ParseError { + Error::InitializationTypeMismatch(ref name_span, ref expected_ty) => ParseError { message: format!("the type of `{}` is expected to be {:?}", &source[name_span.clone()], expected_ty), labels: vec![(name_span.clone(), format!("definition of `{}`", &source[name_span.clone()]).into())], notes: vec![], }, + Error::MissingType(ref name_span) => ParseError { + message: format!("variable `{}` needs a type", &source[name_span.clone()]), + labels: vec![(name_span.clone(), format!("definition of `{}`", &source[name_span.clone()]).into())], + notes: vec![], + }, Error::Other => ParseError { message: "other error".to_string(), labels: vec![], @@ -2567,7 +2573,7 @@ impl Parser { given_inner, expr_inner ); - return Err(Error::LetTypeMismatch(name_span, ty)); + return Err(Error::InitializationTypeMismatch(name_span, ty)); } } block.extend(emitter.finish(context.expressions)); @@ -2583,24 +2589,69 @@ impl Parser { Variable(Handle), } - let (name, _name_span, ty, _access) = - self.parse_variable_ident_decl(lexer, context.types, context.constants)?; + let (name, name_span) = lexer.next_ident_with_span()?; + let given_ty = if lexer.skip(Token::Separator(':')) { + let (ty, _access) = + self.parse_type_decl(lexer, None, context.types, context.constants)?; + Some(ty) + } else { + None + }; - let init = if lexer.skip(Token::Operation('=')) { + let (init, ty) = if lexer.skip(Token::Operation('=')) { emitter.start(context.expressions); let value = self.parse_general_expression( lexer, context.as_expression(block, &mut emitter), )?; block.extend(emitter.finish(context.expressions)); - match context.expressions[value] { + + // prepare the typifier, but work around mutable borrowing... + let _ = context + .as_expression(block, &mut emitter) + .resolve_type(value)?; + + //TODO: share more of this code with `let` arm + let ty = match given_ty { + Some(ty) => { + let expr_inner = context.typifier.get(value, context.types); + let given_inner = &context.types[ty].inner; + if given_inner != expr_inner { + log::error!( + "Given type {:?} doesn't match expected {:?}", + given_inner, + expr_inner + ); + return Err(Error::InitializationTypeMismatch(name_span, ty)); + } + ty + } + None => { + // register the type, if needed + match context.typifier[value].clone() { + TypeResolution::Handle(ty) => ty, + TypeResolution::Value(inner) => context + .types + .fetch_or_append(crate::Type { name: None, inner }), + } + } + }; + + let init = match context.expressions[value] { crate::Expression::Constant(handle) if is_uniform_control_flow => { Init::Constant(handle) } _ => Init::Variable(value), - } + }; + (init, ty) } else { - Init::Empty + match given_ty { + Some(ty) => (Init::Empty, ty), + None => { + log::error!("Variable '{}' without an initializer needs a type", name); + return Err(Error::MissingType(name_span)); + } + } }; lexer.expect(Token::Separator(';'))?; @@ -3186,7 +3237,7 @@ impl Parser { crate::ConstantInner::Composite { ty, components: _ } => ty == explicit_ty, }; if !type_match { - return Err(Error::LetTypeMismatch(name_span, explicit_ty)); + return Err(Error::InitializationTypeMismatch(name_span, explicit_ty)); } //TODO: check `ty` against `const_handle`. lexer.expect(Token::Separator(';'))?; diff --git a/src/front/wgsl/tests.rs b/src/front/wgsl/tests.rs index a537c62d1d..b1c7a12321 100644 --- a/src/front/wgsl/tests.rs +++ b/src/front/wgsl/tests.rs @@ -32,6 +32,8 @@ fn parse_type_inference() { fn foo() { let a = 2u; let b: u32 = a; + var x = 3f32; + var y = vec2(1, 2); }", ) .unwrap(); diff --git a/tests/wgsl-errors.rs b/tests/wgsl-errors.rs index 2c8086102a..0f87c70be0 100644 --- a/tests/wgsl-errors.rs +++ b/tests/wgsl-errors.rs @@ -146,7 +146,7 @@ fn bad_texture() { 7 │ return textureSample(a, sampler, vec2(0.0)); │ ^ not an image -"# +"#, ); } @@ -186,7 +186,7 @@ fn bad_texture_sample_type() { 3 │ [[group(0), binding(1)]] var texture : texture_2d; │ ^^^^ must be one of f32, i32 or u32 -"# +"#, ); } @@ -237,7 +237,7 @@ fn unknown_attribute() { 2 │ [[a]] │ ^ unknown attribute -"# +"#, ); } @@ -253,7 +253,7 @@ fn unknown_built_in() { 2 │ fn x([[builtin(unknown_built_in)]] y: u32) {} │ ^^^^^^^^^^^^^^^^ unknown builtin -"# +"#, ); } @@ -269,7 +269,7 @@ fn unknown_access() { 2 │ var x: [[access(unknown_access)]] array; │ ^^^^^^^^^^^^^^ unknown access -"# +"#, ); } @@ -285,7 +285,7 @@ fn unknown_shader_stage() { 2 │ [[stage(geometry)]] fn main() {} │ ^^^^^^^^ unknown shader stage -"# +"#, ); } @@ -303,7 +303,7 @@ fn unknown_ident() { 3 │ let a = b; │ ^ unknown identifier -"# +"#, ); } @@ -321,7 +321,7 @@ fn unknown_scalar_type() { │ = note: Valid scalar types are f16, f32, f64, i8, i16, i32, i64, u8, u16, u32, u64, bool -"# +"#, ); } @@ -337,7 +337,7 @@ fn unknown_type() { 2 │ let a: Vec; │ ^^^ unknown type -"# +"#, ); } @@ -353,7 +353,7 @@ fn unknown_storage_format() { 2 │ let storage: [[access(read)]] texture_storage_1d; │ ^^^^ unknown storage format -"# +"#, ); } @@ -369,7 +369,7 @@ fn unknown_conservative_depth() { 2 │ [[early_depth_test(abc)]] fn main() {} │ ^^^ unknown conservative depth -"# +"#, ); } @@ -385,7 +385,7 @@ fn zero_array_stride() { 2 │ type zero = [[stride(0)]] array; │ ^ array stride must not be zero -"# +"#, ); } @@ -403,7 +403,7 @@ fn struct_member_zero_size() { 3 │ [[size(0)]] data: array; │ ^ struct member size or alignment must not be 0 -"# +"#, ); } @@ -421,7 +421,7 @@ fn struct_member_zero_align() { 3 │ [[align(0)]] data: array; │ ^ struct member size or alignment must not be 0 -"# +"#, ); } @@ -437,7 +437,7 @@ fn inconsistent_binding() { 2 │ fn foo([[builtin(vertex_index), location(0)]] x: u32) {} │ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ input/output binding is not consistent -"# +"#, ); }